diff --git a/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala b/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala
index 976c489f..b74241b4 100644
--- a/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala
+++ b/src/main/scala/com/databricks/spark/redshift/DefaultSource.scala
@@ -40,6 +40,11 @@ class DefaultSource(
    */
   def this() = this(DefaultJDBCWrapper, awsCredentials => new AmazonS3Client(awsCredentials))
 
+  /**
+   * Constructor to provide a custom S3 client factory
+   */
+  def this(s3ClientFactory: AWSCredentialsProvider => AmazonS3Client) = this(DefaultJDBCWrapper, s3ClientFactory)
+
   /**
    * Create a new RedshiftRelation instance using parameters from Spark SQL DDL. Resolves the schema
    * using JDBC connection over provided URL, which must contain credentials.
diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
index 31dc11b2..a9195179 100644
--- a/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
+++ b/src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
@@ -55,12 +55,12 @@ private[redshift] case class RedshiftRelation(
   }
 
   private val tableNameOrSubquery =
-    params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get
+    params.query.map(q => s"($q) AS q").orElse(params.table.map(_.toString)).get
 
   override lazy val schema: StructType = {
     userSchema.getOrElse {
       val tableNameOrSubquery =
-        params.query.map(q => s"($q)").orElse(params.table.map(_.toString)).get
+        params.query.map(q => s"($q) AS q").orElse(params.table.map(_.toString)).get
       val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, params.credentials)
       try {
         jdbcWrapper.resolveTable(conn, tableNameOrSubquery)
diff --git a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
index ac2a644a..5790bd62 100644
--- a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
+++ b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
@@ -181,7 +181,7 @@ class RedshiftSourceSuite
       |UNLOAD \('SELECT "testbyte", "testbool" FROM
       | \(select testbyte, testbool
       | from test_table
-      | where teststring = \\'\\\\\\\\Unicode\\'\\'s樂趣\\'\) '\)
+      | where teststring = \\'\\\\\\\\Unicode\\'\\'s樂趣\\'\) AS q '\)
       """.stripMargin.lines.map(_.trim).mkString(" ").trim.r
     val query =
       """select testbyte, testbool from test_table where teststring = '\\Unicode''s樂趣'"""
@@ -196,7 +196,7 @@ class RedshiftSourceSuite
     {
       val params = defaultParams + ("dbtable" -> s"($query)")
       val mockRedshift =
-        new MockRedshift(defaultParams("url"), Map(params("dbtable") -> querySchema))
+        new MockRedshift(defaultParams("url"), Map(s"${params("dbtable")} AS q" -> querySchema))
       val relation = new DefaultSource(
         mockRedshift.jdbcWrapper, _ => mockS3Client).createRelation(testSqlContext, params)
       assert(testSqlContext.baseRelationToDataFrame(relation).collect() === expectedValues)
@@ -207,7 +207,7 @@ class RedshiftSourceSuite
     // Test with query parameter
     {
       val params = defaultParams - "dbtable" + ("query" -> query)
-      val mockRedshift = new MockRedshift(defaultParams("url"), Map(s"($query)" -> querySchema))
+      val mockRedshift = new MockRedshift(defaultParams("url"), Map(s"($query) AS q" -> querySchema))
       val relation = new DefaultSource(
         mockRedshift.jdbcWrapper, _ => mockS3Client).createRelation(testSqlContext, params)
       assert(testSqlContext.baseRelationToDataFrame(relation).collect() === expectedValues)
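Note: Redshift, like PostgreSQL, rejects an unaliased derived table ("subquery in FROM must have an alias"), which is why every place the user-supplied query is wrapped in parentheses now appends the fixed alias q. The snippet below is a minimal sketch, not part of the patch, of how the new one-argument constructor might be used to inject a custom S3 client factory; the factory body and endpoint are hypothetical illustrations.

    import com.amazonaws.auth.AWSCredentialsProvider
    import com.amazonaws.services.s3.AmazonS3Client
    import com.databricks.spark.redshift.DefaultSource

    // Build a DefaultSource whose S3 client targets a specific regional
    // endpoint (hypothetical); the JDBC wrapper falls back to
    // DefaultJDBCWrapper via the delegating constructor added above.
    val source = new DefaultSource((credentials: AWSCredentialsProvider) => {
      val client = new AmazonS3Client(credentials)
      client.setEndpoint("s3.eu-west-1.amazonaws.com") // hypothetical endpoint
      client
    })

Because the factory receives the already-resolved AWSCredentialsProvider, a stub client can be substituted the same way the test suite does with `_ => mockS3Client`.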