@@ -35,16 +35,19 @@ def __init__(self, task, dag):
         self._custom_columns = task.custom_columns
         self._input_table_columns_to_include = task.input_table_columns_to_include
         self._input_table_columns_to_exclude = task.input_table_columns_to_exclude
+        self._file_format = task.file_format
+        self._file_prefix = task.file_prefix

     def _generate_command(self):
         command = BatchCreator._generate_command(self)

-        command.append(f"--num_threads={self._num_threads}")
-        command.append(f"--batch_size={self._batch_size}")
         command.append(f"--primary_id_column={self._primary_id_column}")
         command.append(f"--output_type={self._output_type}")
         command.append(f"--glue_registry_name={self._glue_registry_name}")
-
+        if self._num_threads:
+            command.append(f"--num_threads={self._num_threads}")
+        if self._batch_size:
+            command.append(f"--batch_size={self._batch_size}")
         if self._assume_role_arn:
            command.append(f"--assume_role_arn={self._assume_role_arn}")
         if self._secondary_id_column:
@@ -81,6 +84,10 @@ def _generate_command(self):
             command.append(f"--input_table_columns_to_include={self._input_table_columns_to_include}")
         if self._input_table_columns_to_exclude:
             command.append(f"--input_table_columns_to_exclude={self._input_table_columns_to_exclude}")
+        if self._file_format:
+            command.append(f"--file_format={self._file_format}")
+        if self._file_prefix:
+            command.append(f"--file_prefix={self._file_prefix}")

         return command

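Reviewer note: the net effect of the two hunks above is that num_threads and batch_size become optional flags, emitted (like the new file_format and file_prefix) only when set. A minimal sketch of the pattern, assuming a standalone function rather than the real BatchCreator subclass; the command name and sample values are hypothetical:

    def generate_command(num_threads=None, batch_size=None, file_format=None, file_prefix=None):
        # Optional flags are appended only when their value is truthy.
        command = ["reverse_etl_batch"]  # hypothetical executable name
        if num_threads:
            command.append(f"--num_threads={num_threads}")
        if batch_size:
            command.append(f"--batch_size={batch_size}")
        if file_format:
            command.append(f"--file_format={file_format}")
        if file_prefix:
            command.append(f"--file_prefix={file_prefix}")
        return command

    print(generate_command(batch_size=500, file_format="parquet"))
    # ['reverse_etl_batch', '--batch_size=500', '--file_format=parquet']
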
dagger/pipeline/ios/s3_io.py (12 additions & 5 deletions)
@@ -11,6 +11,11 @@ class S3IO(IO):
     def init_attributes(cls, orig_cls):
         cls.add_config_attributes(
             [
+                Attribute(
+                    attribute_name="region_name",
+                    required=False,
+                    comment="Only needed for cross region S3 buckets"
+                ),
                 Attribute(
                     attribute_name="s3_protocol",
                     required=False,
@@ -24,22 +29,21 @@ def __init__(self, io_config, config_location):
     def __init__(self, io_config, config_location):
         super().__init__(io_config, config_location)

+        self._region_name = self.parse_attribute("region_name")
         self._s3_protocol = self.parse_attribute("s3_protocol") or "s3"
         self._bucket = normpath(self.parse_attribute("bucket"))
         self._path = normpath(self.parse_attribute("path"))

     def alias(self):
-        return "s3://{path}".format(path=join(self._bucket, self._path))
+        return f"s3://{self._region_name or ''}/{join(self._bucket, self._path)}"

     @property
     def rendered_name(self):
-        return "{protocol}://{path}".format(
-            protocol=self._s3_protocol, path=join(self._bucket, self._path)
-        )
+        return f"{self._s3_protocol}://{join(self._bucket, self._path)}"

     @property
     def airflow_name(self):
-        return "s3-{}".format(join(self._bucket, self._path).replace("/", "-"))
+        return f"s3-{'-'.join([name_part for name_part in [self._region_name, join(self._bucket, self._path).replace('/', '-')] if name_part])}"

     @property
     def bucket(self):
@@ -49,3 +53,6 @@ def bucket(self):
     def path(self):
         return self._path

+    @property
+    def region_name(self):
+        return self._region_name
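Reviewer note: with region_name in play, the three derived names diverge. A self-contained sketch of the formatting above, assuming plain functions instead of the real S3IO properties; bucket and path values are borrowed from the test fixture further down:

    from posixpath import join

    def s3_names(bucket, path, region_name=None, s3_protocol="s3"):
        # alias() gains a region prefix; rendered_name keeps only the protocol;
        # airflow_name dash-joins the non-empty parts.
        location = join(bucket, path)
        alias = f"s3://{region_name or ''}/{location}"
        rendered = f"{s3_protocol}://{location}"
        airflow = "s3-" + "-".join(p for p in [region_name, location.replace("/", "-")] if p)
        return alias, rendered, airflow

    print(s3_names("test_bucket", "test_path", "eu_west_1"))
    # ('s3://eu_west_1/test_bucket/test_path', 's3://test_bucket/test_path', 's3-eu_west_1-test_bucket-test_path')
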
dagger/pipeline/tasks/reverse_etl_task.py (24 additions & 0 deletions)
@@ -170,6 +170,20 @@ def init_attributes(cls, orig_cls):
                     validator=str,
                     required=False,
                     comment='Optional comma-separated list of columns to exclude from the job. Example: \'column1,column2,column3\', if not provided, all columns of input table will be included',
                 ),
+                Attribute(
+                    attribute_name="file_format",
+                    parent_fields=["task_parameters"],
+                    validator=str,
+                    required=False,
+                    comment="File format for S3 output: 'json' or 'parquet' (required when output_type is 's3')",
+                ),
+                Attribute(
+                    attribute_name="file_prefix",
+                    parent_fields=["task_parameters"],
+                    validator=str,
+                    required=False,
+                    comment="File prefix for S3 output files",
+                )
             ]
         )
@@ -202,6 +216,8 @@ def __init__(self, name, pipeline_name, pipeline, job_config):
         self._custom_columns = self.parse_attribute("custom_columns")
         self._input_table_columns_to_include = self.parse_attribute("input_table_columns_to_include")
         self._input_table_columns_to_exclude = self.parse_attribute("input_table_columns_to_exclude")
+        self._file_format = self.parse_attribute("file_format")
+        self._file_prefix = self.parse_attribute("file_prefix")

         if self._hash_column and self._updated_at_column:
             raise ValueError(f"ReverseETLTask: {self._name} hash_column and updated_at_column are mutually exclusive")
@@ -334,3 +350,11 @@ def input_table_columns_to_include(self):
     @property
     def input_table_columns_to_exclude(self):
         return self._input_table_columns_to_exclude
+
+    @property
+    def file_format(self):
+        return self._file_format
+
+    @property
+    def file_prefix(self):
+        return self._file_prefix
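Reviewer note: end to end, the task parses file_format and file_prefix from task_parameters, and the creator in the first file of this diff forwards them as CLI flags only when set. A minimal sketch, assuming a plain dict in place of the parsed YAML config; values are hypothetical:

    task_parameters = {
        "output_type": "s3",
        "file_format": "parquet",       # 'json' or 'parquet'; required when output_type is 's3'
        "file_prefix": "daily_export",  # hypothetical prefix for the output files
    }

    # Mirror the creator's pattern: emit a flag only when the value is set.
    flags = [f"--{key}={task_parameters[key]}"
             for key in ("file_format", "file_prefix")
             if task_parameters.get(key)]
    print(flags)  # ['--file_format=parquet', '--file_prefix=daily_export']
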
tests/fixtures/pipeline/ios/s3_io.yaml (2 additions & 1 deletion)
@@ -1,4 +1,5 @@
 type: s3
 name: test_s3
 bucket: test_bucket
-path: test_path
\ No newline at end of file
+path: test_path
+region_name: eu_west_1
tests/pipeline/ios/test_s3_io.py (13 additions & 4 deletions)
@@ -12,14 +12,23 @@ def setUp(self) -> None:
     def test_properties(self):
         db_io = s3_io.S3IO(self.config, "/")

-        self.assertEqual(db_io.alias(), "s3://test_bucket/test_path")
+        self.assertEqual(db_io.alias(), "s3://eu_west_1/test_bucket/test_path")
         self.assertEqual(db_io.rendered_name, "s3://test_bucket/test_path")
-        self.assertEqual(db_io.airflow_name, "s3-test_bucket-test_path")
+        self.assertEqual(db_io.airflow_name, "s3-eu_west_1-test_bucket-test_path")

     def test_with_protocol(self):
         self.config['s3_protocol'] = 's3a'
         db_io = s3_io.S3IO(self.config, "/")

-        self.assertEqual(db_io.alias(), "s3://test_bucket/test_path")
+        self.assertEqual(db_io.alias(), "s3://eu_west_1/test_bucket/test_path")
         self.assertEqual(db_io.rendered_name, "s3a://test_bucket/test_path")
-        self.assertEqual(db_io.airflow_name, "s3-test_bucket-test_path")
+        self.assertEqual(db_io.airflow_name, "s3-eu_west_1-test_bucket-test_path")
+
+    def test_with_region_name(self):
+        self.config['region_name'] = 'us-west-2'
+        db_io = s3_io.S3IO(self.config, "/")
+
+        self.assertEqual(db_io.alias(), "s3://us-west-2/test_bucket/test_path")
+        self.assertEqual(db_io.rendered_name, "s3://test_bucket/test_path")
+        self.assertEqual(db_io.airflow_name, "s3-us-west-2-test_bucket-test_path")
+        self.assertEqual(db_io.region_name, "us-west-2")