diff --git a/src/modules/complianceengine/src/lib/CMakeLists.txt b/src/modules/complianceengine/src/lib/CMakeLists.txt index 9f95044955..8640d0b582 100644 --- a/src/modules/complianceengine/src/lib/CMakeLists.txt +++ b/src/modules/complianceengine/src/lib/CMakeLists.txt @@ -138,6 +138,7 @@ add_library(complianceenginelib STATIC DistributionInfo.cpp Engine.cpp Evaluator.cpp + InputStream.cpp FileTreeWalk.cpp FilesystemScanner.cpp GroupsIterator.cpp diff --git a/src/modules/complianceengine/src/lib/DistributionInfo.cpp b/src/modules/complianceengine/src/lib/DistributionInfo.cpp index 4ad634cfce..f074533e8e 100644 --- a/src/modules/complianceengine/src/lib/DistributionInfo.cpp +++ b/src/modules/complianceengine/src/lib/DistributionInfo.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include diff --git a/src/modules/complianceengine/src/lib/Evaluator.h b/src/modules/complianceengine/src/lib/Evaluator.h index 4bf8db09c2..95b5be7f96 100644 --- a/src/modules/complianceengine/src/lib/Evaluator.h +++ b/src/modules/complianceengine/src/lib/Evaluator.h @@ -15,6 +15,13 @@ #include #include +#define AssertResult(variable, ...) \ + if (!(variable).HasValue()) \ + { \ + OsConfigLogError(context.GetLogHandle(), __VA_ARGS__); \ + return (variable).Error(); \ + } + struct json_object_t; struct json_value_t; diff --git a/src/modules/complianceengine/src/lib/InputStream.cpp b/src/modules/complianceengine/src/lib/InputStream.cpp new file mode 100644 index 0000000000..7c05a6e0df --- /dev/null +++ b/src/modules/complianceengine/src/lib/InputStream.cpp @@ -0,0 +1,278 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include + +namespace ComplianceEngine +{ +using std::ifstream; +using std::string; + +InputStream::InputStream(string fileName, ContextInterface& context) + : mContext(context), + mFileName(std::move(fileName)) +{ +} + +InputStream::InputStream(InputStream&& other) noexcept + : mContext(other.mContext), + mFileName(std::move(other.mFileName)), + mStream(std::move(other.mStream)) +{ + assert(mStream); + assert(mStream->is_open()); +} + +InputStream& InputStream::operator=(InputStream&& other) noexcept +{ + if (this == &other) + return *this; + + // The context reference must be constant for moves to work and it's global anyway + assert(&mContext == &other.mContext); + mFileName = std::move(other.mFileName); + mStream = std::move(other.mStream); + assert(mStream); + assert(mStream->is_open()); + return *this; +} + +Result InputStream::Open(const string& fileName, ContextInterface& context) +{ + // access lets us determine readability and obtain error codes + if (0 != ::access(fileName.c_str(), R_OK)) + { + const auto status = errno; + OsConfigLogInfo(context.GetLogHandle(), "Failed to access '%s': %s (%d)", fileName.c_str(), strerror(status), status); + return Error(string("failed to access '") + fileName + "': " + strerror(status), status); + } + + InputStream result(fileName, context); + result.mStream.reset(new std::ifstream(context.GetSpecialFilePath(fileName))); + if (!result.mStream->is_open()) + { + OsConfigLogInfo(context.GetLogHandle(), "Failed to open '%s'", result.mFileName.c_str()); + return Error(string("failed to open '") + result.mFileName + "'"); + } + + assert(result.Good()); + return Result(std::move(result)); +} + +Result InputStream::ReadLine() +{ + assert(mStream); + assert(mStream->is_open()); + if (mBytesRead >= cMaxReadSize) + { + OsConfigLogError(mContext.GetLogHandle(), "Maximum file '%s' read size reached", mFileName.c_str()); + return Error("maximum file '" + mFileName + "' read size reached", E2BIG); + } + + // We want to always check with Good() before reading + if (mStream->eof()) + { + OsConfigLogError(mContext.GetLogHandle(), "Attempted to read file '%s' after EOF", mFileName.c_str()); + return Error(string("attempted to read file '") + mFileName + "' after EOF", EBADFD); + } + + // We won't return empty strings in case an error happened previously + if (mStream->fail() || mStream->bad()) + { + OsConfigLogError(mContext.GetLogHandle(), "Attempted to read file '%s' after failure", mFileName.c_str()); + return Error(string("attempted to read file '") + mFileName + "' after failure", EBADFD); + } + // eof(), fail(), bad() and limits are already checked + assert(Good()); + + // Read the data + string line; + std::getline(*mStream, line); + + // Include the line size without the newline character + mBytesRead += line.size(); + if (!mStream->eof()) + { + // Here we know a newline has been parsed. + ++mBytesRead; + } + + if (mStream->bad()) + { + // fail() may return true here in case there's no trailing newline character, so we stick to bad(). + OsConfigLogError(mContext.GetLogHandle(), "Failed to read line from '%s': %d", mFileName.c_str(), static_cast(mStream->rdstate())); + return Error("failed to read line from '" + mFileName + "'", EBADFD); + } + + return Result(std::move(line)); +} + +bool InputStream::Good() const +{ + assert(mStream->is_open()); + return mBytesRead < cMaxReadSize && mStream->good(); +} + +bool InputStream::AtEnd() const +{ + assert(mStream->is_open()); + return mStream->eof(); +} + +const string& InputStream::GetFileName() const +{ + assert(mStream->is_open()); + return mFileName; +} + +InputStreamIterators::LinesRange InputStream::Lines() & +{ + return InputStreamIterators::LinesRange(*this); +} + +size_t InputStream::BytesRead() const +{ + return mBytesRead; +} + +namespace InputStreamIterators +{ +LinesIterator::LinesIterator(InputStream& stream) + : mStream(stream) +{ +} + +LinesIterator::LinesIterator(const LinesIterator& other) + : mStream(other.mStream), + mValue(other.mValue) +{ +} + +LinesIterator::LinesIterator(LinesIterator&& other) noexcept + : mStream(other.mStream), + mValue(std::move(other.mValue)) +{ +} + +LinesIterator& LinesIterator::operator=(const LinesIterator& other) +{ + if (this == &other) + return *this; + + assert(&mStream == &other.mStream); + mValue = other.mValue; + return *this; +} + +LinesIterator& LinesIterator::operator=(LinesIterator&& other) noexcept +{ + if (this == &other) + return *this; + + assert(&mStream == &other.mStream); + mValue = std::move(other.mValue); + return *this; +} + +LinesIterator& LinesIterator::operator++() +{ + if (mStream.Good()) + mValue = mStream.ReadLine(); + else + mValue.Reset(); + return *this; +} + +LinesIterator LinesIterator::operator++(int) +{ + LinesIterator tmp = *this; + ++*this; + return tmp; +} + +Result LinesIterator::operator*() const& noexcept(false) +{ + CheckValue(); + return mValue.Value(); +} + +Result LinesIterator::operator*() && noexcept(false) +{ + CheckValue(); + return std::move(mValue.Value()); +} + +const Result* LinesIterator::operator->() const noexcept(false) +{ + CheckValue(); + return &mValue.Value(); +} + +bool LinesIterator::IsEnd() const +{ + return mValue.HasValue(); +} + +void LinesIterator::CheckValue() const noexcept(false) +{ + if (!IsEnd()) + { + throw std::logic_error("LinesIterator: unchecked access to Value"); + } +} + +// The comparison makes only sense to compare with the end() iterator +bool LinesIterator::operator==(const LinesIterator& other) const +{ + return mValue.HasValue() == other.mValue.HasValue(); +} + +bool LinesIterator::operator!=(const LinesIterator& other) const +{ + return !(*this == other); +} + +LinesRange::LinesRange(InputStream& stream) + : mStream(stream) +{ +} + +LinesRange::LinesRange(const LinesRange& other) + : mStream(other.mStream) +{ +} + +LinesRange::LinesRange(LinesRange&& other) noexcept + : mStream(other.mStream) +{ +} + +LinesRange& LinesRange::operator=(const LinesRange& other) +{ + UNUSED(other); + assert(&mStream == &other.mStream); + return *this; +} + +LinesRange& LinesRange::operator=(LinesRange&& other) noexcept +{ + UNUSED(other); + assert(&mStream == &other.mStream); + return *this; +} + +LinesIterator LinesRange::begin() const +{ + return ++LinesIterator(mStream); +} + +LinesIterator LinesRange::end() const +{ + return LinesIterator(mStream); +} +} // namespace InputStreamIterators +} // namespace ComplianceEngine diff --git a/src/modules/complianceengine/src/lib/InputStream.h b/src/modules/complianceengine/src/lib/InputStream.h new file mode 100644 index 0000000000..4b8dcbe261 --- /dev/null +++ b/src/modules/complianceengine/src/lib/InputStream.h @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef COMPLIANCEENGINE_FILE_STREAM_H +#define COMPLIANCEENGINE_FILE_STREAM_H + +#include +#include +#include +#include +#include + +namespace ComplianceEngine +{ +namespace InputStreamIterators +{ +class LinesRange; +} // namespace InputStreamIterators + +// This class wraps the C++ std::ifstream with +// CE-specific error handling and a read size limit. +// +// 1. It uses factory method instead of constructors, to provide error handling with Result. +// 2. It guarantees an instance is always assosciated with an open file. +// 3. Any previous errors cause subsequent reads to fail +// 4. Context is used to provide mocking mechanism - filenames can be overridden for testing. +// 5. Maximum number of bytes to read is limited to cMaxReadSize. This is a soft limit as we allow +// to read last full line using std::getline. +class InputStream +{ +public: + // Maximum number of bytes read from an input stream + static constexpr std::size_t cMaxReadSize = 1024 * 1024 * 128; + + InputStream(const InputStream&) = delete; + InputStream(InputStream&&) noexcept; + InputStream& operator=(const InputStream&) = delete; + InputStream& operator=(InputStream&&) noexcept; + ~InputStream() = default; + + // Opens a file for reading. + static Result Open(const std::string& fileName, ContextInterface& context); + + // Reads a single line. + Result ReadLine(); + + // Returns true in case more bytes can be read from the file. + // This means there was no error returned so far, + // we have not reached end of the file, and we have not reached + // the read size limit. + bool Good() const; + + // Returns true in case we have reached end of the file. + bool AtEnd() const; + + // Returns the file name passed to Open. + // Note: In case of mocking, the underlying filename is not stored. + const std::string& GetFileName() const; + + // Obtains a range for line-by-line range-based iteration. + InputStreamIterators::LinesRange Lines() &; + + // Obtains the number of bytes read so far. + std::size_t BytesRead() const; + +private: + InputStream() = delete; + + explicit InputStream(std::string fileName, ContextInterface& context); + + // Holds logger handle and allows mocking. + ContextInterface& mContext; + + // The file name passed to Open. + std::string mFileName; + + // The underlying stream. + // Note: wrapped with unique ptr as fstream is not movable on RHEL7-based platforms. + std::unique_ptr mStream; + + // The number of bytes read so far. + std::size_t mBytesRead = 0; +}; + +namespace InputStreamIterators +{ +// This class allows to iterate a file line by line. +class LinesIterator +{ +public: + LinesIterator(const LinesIterator& other); + LinesIterator(LinesIterator&& other) noexcept; + LinesIterator& operator=(const LinesIterator& other); + LinesIterator& operator=(LinesIterator&& other) noexcept; + ~LinesIterator() = default; + + LinesIterator& operator++(); + LinesIterator operator++(int); + Result operator*() const& noexcept(false); + Result operator*() && noexcept(false); + const Result* operator->() const noexcept(false); + bool operator==(const LinesIterator& other) const; + bool operator!=(const LinesIterator& other) const; + + // Returns true in case this is the end iterator. + bool IsEnd() const; + +private: + friend class LinesRange; + + InputStream& mStream; + // End iterator has nullopt assigned + Optional> mValue = Optional>(); + + // Only range class(es) are able to construct new iterators + explicit LinesIterator(InputStream& stream); + + // Throws in case of end iterator dereference. + void CheckValue() const noexcept(false); +}; + +// This class provides an interface for the range-based for loops. +class LinesRange +{ +public: + explicit LinesRange(InputStream& stream); + LinesRange(const LinesRange& other); + LinesRange(LinesRange&& other) noexcept; + LinesRange& operator=(const LinesRange& other); + LinesRange& operator=(LinesRange&& other) noexcept; + ~LinesRange() = default; + + LinesIterator begin() const; // NOLINT(*-identifier-naming) + LinesIterator end() const; // NOLINT(*-identifier-naming) + +private: + InputStream& mStream; +}; +} // namespace InputStreamIterators +} // namespace ComplianceEngine + +#endif // COMPLIANCEENGINE_FILE_STREAM_H diff --git a/src/modules/complianceengine/src/lib/procedures/FileRegexMatch.cpp b/src/modules/complianceengine/src/lib/procedures/FileRegexMatch.cpp index e4fd64235c..dd8aa12d8d 100644 --- a/src/modules/complianceengine/src/lib/procedures/FileRegexMatch.cpp +++ b/src/modules/complianceengine/src/lib/procedures/FileRegexMatch.cpp @@ -3,16 +3,15 @@ #include #include #include +#include #include #include #include #include #include -#include namespace ComplianceEngine { -using std::ifstream; using std::string; using std::regex_constants::syntax_option_type; namespace @@ -33,11 +32,9 @@ Result MultilineMatch(const std::string& filename, const string& matchPatt Optional matchRegex; Optional stateRegex; - ifstream input(filename); - if (!input.is_open()) - { - return Error("Failed to open file: " + filename, errno); - } + auto input = InputStream::Open(filename, context); + AssertResult(input, "Failed to open '%s'", filename.c_str()); + try { matchRegex = regex(matchPattern, syntaxOptions.first); @@ -53,18 +50,15 @@ Result MultilineMatch(const std::string& filename, const string& matchPatt } int lineNumber = 0; - - string line; - - // Special case for empty files, read empty line then - while (getline(input, line) || lineNumber == 0) + for (const auto& line : input->Lines()) { + AssertResult(line, "Failed to read '%s'", filename.c_str()); lineNumber++; - OsConfigLogDebug(context.GetLogHandle(), "Matching line %d: '%s', pattern: '%s'", lineNumber, line.c_str(), matchPattern.c_str()); + OsConfigLogDebug(context.GetLogHandle(), "Matching line %d: '%s', pattern: '%s'", lineNumber, line->c_str(), matchPattern.c_str()); smatch match; - if (regex_search(line, match, matchRegex.Value())) + if (regex_search(line.Value(), match, matchRegex.Value())) { - OsConfigLogDebug(context.GetLogHandle(), "Matched line %d: %s", lineNumber, line.c_str()); + OsConfigLogDebug(context.GetLogHandle(), "Matched line %d: %s", lineNumber, line->c_str()); if (stateRegex.HasValue()) { assert(match.ready()); @@ -73,13 +67,13 @@ Result MultilineMatch(const std::string& filename, const string& matchPatt OsConfigLogDebug(context.GetLogHandle(), "Value to match: %s", valueToMatch.c_str()); if (regex_search(valueToMatch, stateRegex.Value())) { - OsConfigLogDebug(context.GetLogHandle(), "Matched line %d: %s", lineNumber, line.c_str()); + OsConfigLogDebug(context.GetLogHandle(), "Matched line %d: %s", lineNumber, line->c_str()); return true; } } else { - OsConfigLogDebug(context.GetLogHandle(), "Matched line %d: %s", lineNumber, line.c_str()); + OsConfigLogDebug(context.GetLogHandle(), "Matched line %d: %s", lineNumber, line->c_str()); return true; } } diff --git a/src/modules/complianceengine/src/lib/procedures/PackageInstalled.cpp b/src/modules/complianceengine/src/lib/procedures/PackageInstalled.cpp index 37ea169070..a5db3e7c01 100644 --- a/src/modules/complianceengine/src/lib/procedures/PackageInstalled.cpp +++ b/src/modules/complianceengine/src/lib/procedures/PackageInstalled.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -37,10 +38,12 @@ class ScopeGuard template ScopeGuard(Callable&& f) : f(std::forward(f)){}; + void Deactivate() { active = false; } + ~ScopeGuard() { if (active) @@ -76,32 +79,32 @@ Result DetectPackageManager(ContextInterface& context) return Error("No package manager found", ENOENT); } -Result LoadPackageCache(const std::string& path) +Result LoadPackageCache(const std::string& path, ContextInterface& context) { PackageCache cache; const std::string pkgCacheHeader = "# PackageCache "; // "# PackageCache @\n" cache.lastUpdateTime = 0; cache.packageManager = PackageManagerType::Autodetect; cache.packages.clear(); - std::ifstream cacheFile(path); - if (!cacheFile.is_open()) - { - return Error("Failed to open cache file: " + path); - } + auto cacheFile = ComplianceEngine::InputStream::Open(path, context); + AssertResult(cacheFile, "Failed to open cache file"); - std::string header; - if (!std::getline(cacheFile, header) || (0 != header.find(pkgCacheHeader, 0))) + auto header = cacheFile->ReadLine(); + AssertResult(header, "Failed to read cache header"); + if (0 != header->find(pkgCacheHeader, 0)) { + OsConfigLogError(context.GetLogHandle(), "Invalid cache file format"); return Error("Invalid cache file format"); } - auto separatorPos = header.find('@'); + const auto separatorPos = header->find('@'); if (std::string::npos == separatorPos) { + OsConfigLogError(context.GetLogHandle(), "Invalid cache file format"); return Error("Invalid cache file header format"); } - const auto packageMangerStr = header.substr(pkgCacheHeader.length(), separatorPos - pkgCacheHeader.length()); + const auto packageMangerStr = header->substr(pkgCacheHeader.length(), separatorPos - pkgCacheHeader.length()); if (packageMangerStr == "dpkg") cache.packageManager = PackageManagerType::DPKG; else if (packageMangerStr == "rpm") @@ -110,34 +113,30 @@ Result LoadPackageCache(const std::string& path) return Error("Invalid package manager type"); try { - cache.lastUpdateTime = std::stol(header.substr(separatorPos + 1)); + cache.lastUpdateTime = std::stol(header->substr(separatorPos + 1)); } catch (const std::exception&) { return Error("Invalid timestamp in cache file header"); } - std::string line; - while (std::getline(cacheFile, line)) + for (auto line : cacheFile->Lines()) { - if (!line.empty()) + AssertResult(line, "Failed to read cache entry"); + + if (!line->empty()) { - auto sepPos = line.find(' '); + auto sepPos = line->find(' '); if (sepPos == std::string::npos) { continue; // Skip lines without a space } - std::string packageName = line.substr(0, sepPos); - std::string packageVersion = line.substr(sepPos + 1); + std::string packageName = line->substr(0, sepPos); + std::string packageVersion = line->substr(sepPos + 1); cache.packages[packageName] = packageVersion; } } - if (cacheFile.bad()) - { - return Error("Error reading cache file"); - } - return cache; } @@ -454,7 +453,7 @@ Result AuditPackageInstalled(const PackageInstalledParams& params, Indic auto log = context.GetLogHandle(); PackageCache cache; - auto cacheResult = LoadPackageCache(params.test_cachePath.Value()); + auto cacheResult = LoadPackageCache(params.test_cachePath.Value(), context); bool cacheValid = true; bool cacheStale = false; if (cacheResult.HasValue()) diff --git a/src/modules/complianceengine/tests/CMakeLists.txt b/src/modules/complianceengine/tests/CMakeLists.txt index c7a98661a0..c6270abacc 100644 --- a/src/modules/complianceengine/tests/CMakeLists.txt +++ b/src/modules/complianceengine/tests/CMakeLists.txt @@ -58,6 +58,7 @@ add_executable(complianceenginetests BindingsTest.cpp CommonContextTest.cpp ComplianceEngineTest.cpp + InputStreamTest.cpp FilesystemScannerTest.cpp DistributionInfoTest.cpp EngineTest.cpp diff --git a/src/modules/complianceengine/tests/InputStreamTest.cpp b/src/modules/complianceengine/tests/InputStreamTest.cpp new file mode 100644 index 0000000000..f8be42dbde --- /dev/null +++ b/src/modules/complianceengine/tests/InputStreamTest.cpp @@ -0,0 +1,270 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +using ComplianceEngine::Error; +using ComplianceEngine::InputStream; +using ComplianceEngine::Result; +using std::string; + +class InputStreamTest : public ::testing::Test +{ +protected: + MockContext mContext; +}; + +TEST_F(InputStreamTest, DoesNotExist) +{ + auto result = InputStream::Open("nonexistentfile", mContext); + ASSERT_FALSE(result.HasValue()); + EXPECT_EQ(result.Error().code, ENOENT); +} + +TEST_F(InputStreamTest, EmptyFile) +{ + const auto filename = mContext.MakeTempfile(""); + auto result = InputStream::Open(filename, mContext); + ASSERT_TRUE(result.HasValue()); + + // Not yet at end as we haven't read anything yet. + EXPECT_TRUE(result->Good()); + EXPECT_FALSE(result->AtEnd()); + + // This should return an empty line + auto line = result->ReadLine(); + ASSERT_TRUE(line.HasValue()); + EXPECT_EQ(line.Value(), string()); + + // Subsequent reads should fail as we've reached EOF + EXPECT_FALSE(result->Good()); + EXPECT_TRUE(result->AtEnd()); + line = result->ReadLine(); + ASSERT_FALSE(line.HasValue()); + EXPECT_EQ(line.Error().code, EBADFD); +} + +TEST_F(InputStreamTest, SingleLine) +{ + const auto filename = mContext.MakeTempfile("foo\n"); + auto result = InputStream::Open(filename, mContext); + ASSERT_TRUE(result.HasValue()); + + // Not yet at end as we haven't read anything yet. + EXPECT_TRUE(result->Good()); + EXPECT_FALSE(result->AtEnd()); + + // This should return a line with 'foo' contents + auto line = result->ReadLine(); + ASSERT_TRUE(line.HasValue()); + EXPECT_EQ(line.Value(), string("foo")); + + // Not yet at end as we have not reached EOF state yet + EXPECT_TRUE(result->Good()); + EXPECT_FALSE(result->AtEnd()); + + // This should return an empty line + line = result->ReadLine(); + ASSERT_TRUE(line.HasValue()); + EXPECT_EQ(line.Value(), string()); + + // Subsequent reads should fail as we've reached EOF + EXPECT_FALSE(result->Good()); + EXPECT_TRUE(result->AtEnd()); + line = result->ReadLine(); + ASSERT_FALSE(line.HasValue()); + EXPECT_EQ(line.Error().code, EBADFD); +} + +TEST_F(InputStreamTest, MultipleLines) +{ + const auto filename = mContext.MakeTempfile("foo \n bar \r\nbaz"); + auto result = InputStream::Open(filename, mContext); + ASSERT_TRUE(result.HasValue()); + + // Not yet at end as we haven't read anything yet. + EXPECT_TRUE(result->Good()); + EXPECT_FALSE(result->AtEnd()); + + // This should return a line with 'foo' contents + auto line = result->ReadLine(); + ASSERT_TRUE(line.HasValue()); + EXPECT_EQ(line.Value(), string("foo ")); + + // Not yet at end as we have not reached EOF state yet + EXPECT_TRUE(result->Good()); + EXPECT_FALSE(result->AtEnd()); + + // This should return a line with 'bar' contents. + // It will include the \r as well as it runs on Linux (std::getline behavior) + line = result->ReadLine(); + ASSERT_TRUE(line.HasValue()); + EXPECT_EQ(line.Value(), string(" bar \r")); + + // This should return a line with 'baz' contents + line = result->ReadLine(); + ASSERT_TRUE(line.HasValue()); + EXPECT_EQ(line.Value(), string("baz")); + + // We've reached the EOF as there's no line ending at the end of input. + EXPECT_FALSE(result->Good()); + EXPECT_TRUE(result->AtEnd()); + line = result->ReadLine(); + ASSERT_FALSE(line.HasValue()); + EXPECT_EQ(line.Error().code, EBADFD); +} + +TEST_F(InputStreamTest, Range_MultipleLines) +{ + const auto filename = mContext.MakeTempfile("foo\nbar\nbaz\n\n"); + auto result = InputStream::Open(filename, mContext); + ASSERT_TRUE(result.HasValue()); + + string test; + int counter = 0; + // Test the LinesRange with iterator for range-based for loops use-case + for (auto line : result->Lines()) + { + ASSERT_TRUE(line.HasValue()); + test += line.Value(); + ++counter; + } + + EXPECT_FALSE(result->Good()); + EXPECT_EQ(test, string("foobarbaz")); + EXPECT_EQ(counter, 5); + EXPECT_FALSE(result->Good()); + EXPECT_TRUE(result->AtEnd()); +} + +TEST_F(InputStreamTest, Mocking) +{ + const auto filename = mContext.MakeTempfile("foo"); + mContext.SetSpecialFilePath("/etc/passwd", filename); + + // The /etc/passwd file should be masked by the tempfile we've just created + auto result = InputStream::Open("/etc/passwd", mContext); + ASSERT_TRUE(result.HasValue()); + auto line = result->ReadLine(); + ASSERT_TRUE(line.HasValue()); + EXPECT_EQ(line.Value(), string("foo")); +} + +TEST_F(InputStreamTest, LimitsHandling_1) +{ + const auto input = string(); + const auto filename = mContext.MakeTempfile(input); + auto result = InputStream::Open(filename, mContext); + ASSERT_TRUE(result.HasValue()); + auto line = result->ReadLine(); + ASSERT_TRUE(line.HasValue()); + EXPECT_EQ(line.Value(), string()); + EXPECT_EQ(result->BytesRead(), input.size()); +} + +TEST_F(InputStreamTest, LimitsHandling_2) +{ + const auto input = string("foo"); + const auto filename = mContext.MakeTempfile(input); + auto result = InputStream::Open(filename, mContext); + ASSERT_TRUE(result.HasValue()); + auto line = result->ReadLine(); + ASSERT_TRUE(line.HasValue()); + EXPECT_EQ(line.Value(), string("foo")); + EXPECT_EQ(result->BytesRead(), input.size()); +} + +TEST_F(InputStreamTest, LimitsHandling_3) +{ + const auto input = string("foo\n"); + const auto filename = mContext.MakeTempfile(input); + auto result = InputStream::Open(filename, mContext); + ASSERT_TRUE(result.HasValue()); + auto line = result->ReadLine(); + ASSERT_TRUE(line.HasValue()); + EXPECT_EQ(line.Value(), string("foo")); + EXPECT_EQ(result->BytesRead(), input.size()); +} + +TEST_F(InputStreamTest, LimitsHandling_4) +{ + const auto input = string("foo\n\nbar"); + const auto filename = mContext.MakeTempfile(input); + auto result = InputStream::Open(filename, mContext); + ASSERT_TRUE(result.HasValue()); + for (auto line : result->Lines()) + { + ASSERT_TRUE(line.HasValue()); + } + EXPECT_EQ(result->BytesRead(), input.size()); +} + +TEST_F(InputStreamTest, LimitsHandling_6) +{ + const auto input = string(InputStream::cMaxReadSize, 'x'); + const auto filename = mContext.MakeTempfile(input); + auto result = InputStream::Open(filename, mContext); + ASSERT_TRUE(result.HasValue()); + std::size_t counter = 0; + for (auto line : result->Lines()) + { + ASSERT_TRUE(line.HasValue()); + ++counter; + } + EXPECT_EQ(counter, 1); + EXPECT_EQ(result->BytesRead(), input.size()); +} + +TEST_F(InputStreamTest, LimitsHandling_7) +{ + constexpr auto limit = InputStream::cMaxReadSize; + string input; + for (std::size_t i = 0; i < 1023; ++i) + { + input += string(InputStream::cMaxReadSize / 1024, 'x'); + input += '\n'; + } + const auto filename = mContext.MakeTempfile(input); + auto result = InputStream::Open(filename, mContext); + ASSERT_TRUE(result.HasValue()); + std::size_t counter = 0; + for (auto line : result->Lines()) + { + ASSERT_TRUE(line.HasValue()); + ++counter; + } + // +1 for the trailing empty line + EXPECT_EQ(counter, 1024); + EXPECT_TRUE(result->BytesRead() < limit); +} + +TEST_F(InputStreamTest, LimitsHandling_8) +{ + constexpr auto limit = InputStream::cMaxReadSize; + string input; + for (std::size_t i = 0; i < 1024; ++i) + { + input += string(InputStream::cMaxReadSize / 1024, 'x'); + input += '\n'; + } + const auto filename = mContext.MakeTempfile(input); + mContext.SetSpecialFilePath("/etc/passwd", filename); + auto result = InputStream::Open("/etc/passwd", mContext); + ASSERT_TRUE(result.HasValue()); + std::size_t counter = 0; + for (auto line : result->Lines()) + { + ASSERT_TRUE(line.HasValue()); + ++counter; + } + EXPECT_EQ(counter, 1024); + // We exceed the limit here, but won't be able to read next time + EXPECT_TRUE(result->BytesRead() > limit); + + // Limit reached + auto line = result->ReadLine(); + ASSERT_FALSE(line.HasValue()); + EXPECT_EQ(line.Error().code, E2BIG); +}