Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions benchmark/simple_ml_pipeline_yesworkflow/data-acquisition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/usr/bin/env python3

# @BEGIN main
# @PARAM data_url
# @OUT raw_data @URI file:data/raw_data.csv
# @OUT features @AS feature_names
# @OUT target @AS target_variable

import pandas as pd
import numpy as np
import os

def acquire_data(data_url: str = "http://lib.stat.cmu.edu/datasets/boston"):
    """
    Acquires and processes the Boston Housing dataset.

    Reads the whitespace-delimited raw file, joins every pair of
    consecutive rows into a single record, labels the 13 feature columns,
    appends the target as ``price`` and writes the combined table to
    ``data/raw_data.csv`` (the ``data`` directory is created if missing).

    Args:
        data_url: URL or local path of the raw file, passed straight to
            ``pandas.read_csv``. Defaults to the CMU StatLib copy.

    Returns:
        The combined DataFrame: 13 feature columns followed by ``price``.

    @BEGIN fetch_raw_data
    @IN data_url
    @OUT raw_df @AS raw_dataset
    @END fetch_raw_data

    @BEGIN process_data
    @IN raw_df @AS raw_dataset
    @OUT data @AS feature_matrix
    @OUT target @AS target_variable
    @END process_data

    @BEGIN create_features
    @IN feature_names
    @IN data @AS feature_matrix
    @IN target @AS target_variable
    @OUT X @AS feature_dataframe
    @OUT y @AS target_series
    @END create_features
    """
    print("Acquiring Boston Housing dataset...")

    # @BEGIN fetch_raw_data
    # The separator is a raw string: a plain "\s" is an invalid escape
    # sequence and triggers a warning on current Python versions.
    # The first 22 lines of the file are skipped before parsing.
    raw_df = pd.read_csv(data_url, sep=r"\s+", skiprows=22, header=None)
    # @END fetch_raw_data

    # @BEGIN process_data
    # Each record is spread over two consecutive rows: the even row supplies
    # all of its columns, the odd row supplies two more feature columns and,
    # in its third column, the target value.
    values = raw_df.values
    data = np.hstack([values[::2, :], values[1::2, :2]])
    target = values[1::2, 2]
    # @END process_data

    # @BEGIN create_features
    feature_names = [
        'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD',
        'TAX', 'PTRATIO', 'B', 'LSTAT',
    ]
    X = pd.DataFrame(data, columns=feature_names)
    y = pd.Series(target, name='price')
    # @END create_features

    # @BEGIN save_data
    # @IN X @AS feature_dataframe
    # @IN y @AS target_series
    # @OUT raw_data @URI file:data/raw_data.csv
    # A separate name is used for the combined frame so the feature matrix
    # above is not shadowed.
    dataset = pd.concat([X, y], axis=1)

    os.makedirs('data', exist_ok=True)
    dataset.to_csv('data/raw_data.csv', index=False)
    # @END save_data

    print(f"Dataset saved with shape: {dataset.shape}")
    print(f"Features: {list(feature_names)}")
    print(f"Target: {y.name}")
    print("Data acquisition complete!")

    return dataset

# @END main

if __name__ == "__main__":
    # Executed as a script: run the acquisition step once.
    acquire_data()
36 changes: 36 additions & 0 deletions benchmark/simple_ml_pipeline_yesworkflow/data_acquisition.dot
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Graphviz drawing of the "main" workflow: four step nodes joined by
// labelled data-flow edges, plus workflow-level input and output ports.
// NOTE(review): this looks like tool-generated output for
// data-acquisition.py; confirm before editing it by hand.
digraph Workflow {
rankdir=LR
fontname=Courier; fontsize=18; labelloc=t
label=main
// Outer black frame and inner white frame enclosing the workflow steps.
subgraph cluster_workflow_box_outer { label=""; color=black; penwidth=2
subgraph cluster_workflow_box_inner { label=""; color=white
// Step nodes, drawn as filled boxes.
node[shape=box style=filled fillcolor="#CCFFCC" peripheries=1 fontname=Courier]
fetch_raw_data
process_data
create_features
save_data
// Data passed from one step to the next; each label names the value carried.
edge[fontname=Helvetica]
process_data -> create_features [label=target_variable]
fetch_raw_data -> process_data [label=raw_dataset]
process_data -> create_features [label=feature_matrix]
create_features -> save_data [label=feature_dataframe]
create_features -> save_data [label=target_series]
}}
// Workflow input port, drawn as a small unlabeled white circle.
subgraph cluster_input_ports_group_outer { label=""; color=white
subgraph cluster_input_ports_group_inner { label=""; color=white
node[shape=circle style=filled fillcolor="#FFFFFF" peripheries=1 fontname=Courier width=0.2]
data_url_input_port [label=""]
}}
// Workflow output ports, drawn the same way.
// NOTE(review): feature_names_output_port is declared here, but no edge in
// this graph points to it, so it renders as a disconnected circle.
subgraph cluster_output_ports_group_outer { label=""; color=white
subgraph cluster_output_ports_group_inner { label=""; color=white
node[shape=circle style=filled fillcolor="#FFFFFF" peripheries=1 fontname=Courier width=0.2]
raw_data_output_port [label=""]
feature_names_output_port [label=""]
target_variable_output_port [label=""]
}}
// Edge from the input port into the first step.
edge[fontname=Helvetica]
data_url_input_port -> fetch_raw_data [label=data_url]
// Edges from steps out to the output ports.
edge[fontname=Helvetica]
save_data -> raw_data_output_port [label=raw_data]
process_data -> target_variable_output_port [label=target_variable]
}