Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ def __init__(self, task, dag):
self._primary_id_column = task.primary_id_column
self._secondary_id_column = task.secondary_id_column
self._custom_id_column = task.custom_id_column
self._model_name = task.model_name
self._project_name = task.project_name
self._is_deleted_column = task.is_deleted_column
self._hash_column = task.hash_column
self._updated_at_column = task.updated_at_column
Expand All @@ -28,16 +26,22 @@ def __init__(self, task, dag):
self._output_type = task.output_type
self._region_name = task.region_name
self._full_refresh = task.full_refresh
self._target_case = task.target_case
self._source_case = task.source_case
self._column_mapping = task.column_mapping
self._glue_registry_name = task.glue_registry_name
self._glue_schema_name = task.glue_schema_name
self._sort_key = task.sort_key
self._custom_columns = task.custom_columns

def _generate_command(self):
command = BatchCreator._generate_command(self)

command.append(f"--num_threads={self._num_threads}")
command.append(f"--batch_size={self._batch_size}")
command.append(f"--primary_id_column={self._primary_id_column}")
command.append(f"--model_name={self._model_name}")
command.append(f"--project_name={self._project_name}")
command.append(f"--output_type={self._output_type}")
command.append(f"--glue_registry_name={self._glue_registry_name}")

if self._assume_role_arn:
command.append(f"--assume_role_arn={self._assume_role_arn}")
Expand All @@ -59,6 +63,19 @@ def _generate_command(self):
command.append(f"--region_name={self._region_name}")
if self._full_refresh:
command.append(f"--full_refresh={self._full_refresh}")
if self._target_case:
command.append(f"--target_case={self._target_case}")
if self._source_case:
command.append(f"--source_case={self._source_case}")
if self._column_mapping:
command.append(f"--column_mapping={self._column_mapping}")
if self._glue_schema_name:
command.append(f"--glue_schema_name={self._glue_schema_name}")
if self._sort_key:
command.append(f"--sort_key={self._sort_key}")
if self._custom_columns:
command.append(f"--custom_columns={self._custom_columns}")


return command

Expand Down
110 changes: 85 additions & 25 deletions dagger/pipeline/tasks/reverse_etl_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,6 @@ def init_attributes(cls, orig_cls):
required=False,
comment="The custom key column to use for the job",
),
Attribute(
attribute_name="model_name",
parent_fields=["task_parameters"],
validator=str,
required=False,
comment="The name of the model. This is going to be a column on the target table. By default it is"
" set to the name of the input <schema>.<table>",
),
Attribute(
attribute_name="project_name",
parent_fields=["task_parameters"],
validator=str,
required=True,
comment="The name of the project. This is going to be a column on the target table.",
),
Attribute(
attribute_name="is_deleted_column",
parent_fields=["task_parameters"],
Expand Down Expand Up @@ -122,8 +107,58 @@ def init_attributes(cls, orig_cls):
validator=bool,
required=False,
comment="If set to True, the job will perform a full refresh instead of an incremental one",
),
Attribute(
attribute_name="target_case",
parent_fields=["task_parameters"],
validator=str,
required=False,
comment="Target column case for DynamoDB. 'snake' leaves columns in snake_case; 'camel' converts to camelCase.",
),
Attribute(
attribute_name="source_case",
parent_fields=["task_parameters"],
validator=str,
required=False,
comment="Source dataset column case. Specify the case of the incoming dataset."
),
Attribute(
attribute_name="column_mapping",
parent_fields=["task_parameters"],
validator=str,
required=False,
comment='Optional JSON string for column mappings. Example: \'{"id": "chat_id"}\'',
),
Attribute(
attribute_name="glue_registry_name",
parent_fields=["task_parameters"],
validator=str,
required=True,
comment='AWS Glue Registry name',
),
Attribute(
attribute_name="glue_schema_name",
parent_fields=["task_parameters"],
validator=str,
required=False,
comment='AWS Glue Schema name. output_name will be used if not provided',
),
Attribute(
attribute_name="sort_key",
parent_fields=["task_parameters"],
validator=str,
required=False,
comment='Optional JSON string for sort key composition using #.join(). Example: \'{"sort_key": ["project", "model_name", "secondary_id", "custom_id"]}\'',
),
Attribute(
attribute_name="custom_columns",
parent_fields=["task_parameters"],
validator=str,
required=False,
comment='Optional JSON string for additional custom columns from static values. Example: \'{"custom_project": "ProjectXYZ", "model_name": "ModelABC"}\''
)


]
)

Expand All @@ -140,14 +175,19 @@ def __init__(self, name, pipeline_name, pipeline, job_config):
self._primary_id_column = self.parse_attribute("primary_id_column")
self._secondary_id_column = self.parse_attribute("secondary_id_column")
self._custom_id_column = self.parse_attribute("custom_id_column")
self._model_name = self.parse_attribute("model_name")
self._project_name = self.parse_attribute("project_name")
self._is_deleted_column = self.parse_attribute("is_deleted_column")
self._hash_column = self.parse_attribute("hash_column")
self._updated_at_column = self.parse_attribute("updated_at_column")
self._from_time = self.parse_attribute("from_time")
self._days_to_live = self.parse_attribute("days_to_live")
self._full_refresh = self.parse_attribute("full_refresh")
self._target_case = self.parse_attribute("target_case")
self._source_case = self.parse_attribute("source_case")
self._column_mapping = self.parse_attribute("column_mapping")
self._glue_registry_name = self.parse_attribute("glue_registry_name")
self._glue_schema_name = self.parse_attribute("glue_schema_name")
self._sort_key = self.parse_attribute("sort_key")
self._custom_columns = self.parse_attribute("custom_columns")

if self._hash_column and self._updated_at_column:
raise ValueError(f"ReverseETLTask: {self._name} hash_column and updated_at_column are mutually exclusive")
Expand Down Expand Up @@ -210,14 +250,6 @@ def secondary_id_column(self):
def custom_id_column(self):
    """Return the custom key column configured for the job (may be None; the
    attribute is declared as optional in ``init_attributes``)."""
    return self._custom_id_column

@property
def model_name(self):
    """Return the model name written as a column on the target table.

    Per the attribute declaration it is optional and defaults to the input
    ``<schema>.<table>`` name — TODO confirm where that default is applied.
    """
    return self._model_name

@property
def project_name(self):
    """Return the project name written as a column on the target table
    (declared as a required task parameter)."""
    return self._project_name

@property
def is_deleted_column(self):
    """Return the name of the soft-delete marker column parsed from the
    task parameters (may be None when not configured)."""
    return self._is_deleted_column
Expand Down Expand Up @@ -249,3 +281,31 @@ def region_name(self):
@property
def full_refresh(self):
    """Return the full-refresh flag: per the attribute declaration, True
    makes the job perform a full refresh instead of an incremental one."""
    return self._full_refresh

@property
def target_case(self):
    """Return the target column case for DynamoDB ('snake' keeps snake_case,
    'camel' converts to camelCase, per the attribute declaration); optional."""
    return self._target_case

@property
def source_case(self):
    """Return the declared column case of the incoming source dataset
    (optional task parameter)."""
    return self._source_case

@property
def column_mapping(self):
    """Return the optional column-mapping JSON string, e.g.
    '{"id": "chat_id"}' (kept as a raw string; parsing presumably happens
    downstream — confirm in the consumer)."""
    return self._column_mapping

@property
def glue_registry_name(self):
    """Return the AWS Glue Registry name (declared as a required task
    parameter)."""
    return self._glue_registry_name

@property
def glue_schema_name(self):
    """Return the AWS Glue Schema name; optional — per the attribute
    declaration, output_name is used when not provided (fallback applied
    elsewhere — TODO confirm)."""
    return self._glue_schema_name

@property
def sort_key(self):
    """Return the optional sort-key composition JSON string, e.g.
    '{"sort_key": ["project", "model_name", "secondary_id", "custom_id"]}'
    (per the attribute declaration, components are joined with '#')."""
    return self._sort_key

@property
def custom_columns(self):
    """Return the optional JSON string of additional static-value columns,
    e.g. '{"custom_project": "ProjectXYZ", "model_name": "ModelABC"}'."""
    return self._custom_columns