
TabDM API Reference

This reference documents the public generation, transformation, and evaluation APIs exported from tabdm.

Generation API

generate_synthetic_data

generate_synthetic_data(
    dataframe,
    *,
    num_rows=None,
    discrete_columns=None,
    column_metadata=None,
    target_column=None,
    sensitive_columns=None,
    condition_on=None,
    conditions=None,
    condition_strategy="prior",
    hidden_dims=(256, 256),
    time_embedding_dim=64,
    timesteps=96,
    sample_steps=24,
    epochs=120,
    batch_size=512,
    learning_rate=1e-3,
    weight_decay=1e-6,
    beta_start=1e-4,
    beta_end=0.02,
    dropout=0.0,
    discrete_loss_weight=2.0,
    prediction_clip=1.5,
    grad_clip_norm=1.0,
    device="cpu",
    random_state=None,
    verbose=False,
    return_model=False,
)

Fits a TabDM model and returns synthetic rows.

Returns:

  • pandas.DataFrame when return_model=False
  • SyntheticDataResult when return_model=True

SyntheticDataResult fields:

| Field | Meaning |
|---|---|
| synthetic_data | Generated dataframe. |
| model | Fitted TabDM instance. |
| discrete_columns | Resolved discrete column names. |
| condition_columns | Resolved condition column names. |

fit_tabdm

fit_tabdm(
    dataframe,
    *,
    discrete_columns=None,
    column_metadata=None,
    target_column=None,
    sensitive_columns=None,
    condition_on=None,
    hidden_dims=(256, 256),
    time_embedding_dim=64,
    timesteps=96,
    sample_steps=24,
    epochs=120,
    batch_size=512,
    learning_rate=1e-3,
    weight_decay=1e-6,
    beta_start=1e-4,
    beta_end=0.02,
    dropout=0.0,
    discrete_loss_weight=2.0,
    prediction_clip=1.5,
    grad_clip_norm=1.0,
    device="cpu",
    random_state=None,
    verbose=False,
)

Fits and returns a reusable TabDM instance. Use this API when you want to generate multiple samples from the same fitted model.

TabDMConfig

TabDMConfig(
    hidden_dims=(256, 256),
    time_embedding_dim=64,
    timesteps=96,
    sample_steps=24,
    epochs=120,
    batch_size=512,
    learning_rate=1e-3,
    weight_decay=1e-6,
    beta_start=1e-4,
    beta_end=0.02,
    dropout=0.0,
    discrete_loss_weight=2.0,
    prediction_clip=1.5,
    grad_clip_norm=1.0,
    device="cpu",
    random_state=None,
    verbose=False,
)

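Configuration object that holds the model and training hyperparameters when you use TabDM directly instead of through the high-level functions.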

TabDM.fit

model.fit(
    train_data,
    *,
    discrete_columns=None,
    column_metadata=None,
    target_column=None,
    sensitive_columns=None,
    condition_on=None,
)

Fits a model to a dataframe or 2D NumPy array. Dataframe input is recommended because it preserves column names and dtypes.

TabDM.sample and TabDM.generate

model.sample(
    num_rows,
    *,
    conditions=None,
    condition_strategy="prior",
    random_state=None,
)

generate is an alias for sample.

Arguments:

| Argument | Meaning |
|---|---|
| num_rows | Number of synthetic rows. Must be positive. |
| conditions | Optional fixed or row-wise condition values. |
| condition_strategy | "prior" or "balanced"; controls how condition rows are sampled when they must be drawn from the training data. |
| random_state | Seed for reproducible sampling. |
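The reference does not spell out what the two condition strategies do. One plausible reading is sketched below, under that assumption: "prior" draws condition rows with their empirical training frequencies, while "balanced" draws each distinct condition combination with equal probability. sample_condition_rows is a hypothetical helper, not a TabDM function, and TabDM's actual logic may differ.

```python
import random


def sample_condition_rows(training_conditions, num_rows, strategy="prior", seed=None):
    """Draw num_rows condition tuples from the training condition rows.

    Hypothetical helper (assumed semantics, not TabDM's implementation):
    "prior" keeps the training frequencies, "balanced" treats every
    distinct condition combination as equally likely.
    """
    if num_rows <= 0:
        raise ValueError("num_rows must be positive")
    rng = random.Random(seed)
    if strategy == "prior":
        # Choosing uniformly over training rows reproduces the empirical prior.
        return [rng.choice(training_conditions) for _ in range(num_rows)]
    if strategy == "balanced":
        # Choosing uniformly over distinct combinations equalizes rare and common groups.
        distinct = sorted(set(training_conditions))
        return [rng.choice(distinct) for _ in range(num_rows)]
    raise ValueError(f"unknown condition_strategy: {strategy!r}")
```

With nine ("yes",) rows and one ("no",) row, "prior" yields "no" about 10% of the time and "balanced" about 50% of the time.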

Generation Parameters

Data Parameters

| Parameter | Type | Notes |
|---|---|---|
| dataframe | pandas.DataFrame | Required for high-level APIs. Must not be empty. |
| train_data | dataframe or 2D array | Accepted by TabDM.fit. |
| num_rows | positive int | If omitted in generate_synthetic_data, defaults to the number of training rows. |
| discrete_columns | sequence of strings | Optional. If omitted, object, string, categorical, and bool columns are inferred as discrete. |
| column_metadata | mapping | Optional metadata for special numeric or ordinal handling. |

Conditioning Parameters

| Parameter | Type | Notes |
|---|---|---|
| target_column | string | Treated as a condition column; the model generates features conditional on it. |
| sensitive_columns | sequence of strings | Also treated as condition columns. |
| condition_on | sequence of strings | Additional condition columns. |
| conditions | dataframe or mapping | Condition values to use during generation. |
| condition_strategy | "prior" or "balanced" | Used when condition rows must be sampled from the training data. |

Condition column order is:

  1. condition_on
  2. target_column
  3. sensitive_columns

Duplicate names are removed while preserving that order.
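The ordering and de-duplication rule can be sketched as follows. resolve_condition_columns is a hypothetical helper that mirrors the documented rule; it is not part of the tabdm API.

```python
def resolve_condition_columns(condition_on=None, target_column=None, sensitive_columns=None):
    """Order condition columns as condition_on, target_column, sensitive_columns.

    Hypothetical helper: duplicates are dropped, keeping each name's first position.
    """
    ordered = list(condition_on or [])
    if target_column is not None:
        ordered.append(target_column)
    ordered.extend(sensitive_columns or [])
    # dict preserves insertion order, so fromkeys keeps the first occurrence of each name.
    return list(dict.fromkeys(ordered))
```

For example, condition_on=["region", "gender"], target_column="income", and sensitive_columns=["gender", "age"] resolve to ["region", "gender", "income", "age"]: "gender" keeps its position from condition_on.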

Model Parameters

| Parameter | Type | Notes |
|---|---|---|
| hidden_dims | tuple of positive ints | Hidden layer sizes of the MLP denoiser. |
| time_embedding_dim | positive int | Timestep embedding dimension. |
| timesteps | int greater than 1 | Number of diffusion steps used in training. |
| sample_steps | positive int | Number of deterministic reverse steps used in sampling. |
| epochs | positive int | Number of training epochs. |
| batch_size | positive int | Batch size for training; also the chunk size for sampling. |
| learning_rate | positive float | AdamW learning rate. |
| weight_decay | non-negative float | AdamW weight decay. |
| beta_start | float | Must satisfy 0 < beta_start < beta_end < 1. |
| beta_end | float | Must satisfy 0 < beta_start < beta_end < 1. |
| dropout | float | Dropout probability for hidden layers. |
| discrete_loss_weight | positive float | Weight for softmax (discrete) spans in the reconstruction loss. |
| prediction_clip | positive float | Clamp applied to predicted transformed features. |
| grad_clip_norm | non-negative float | Gradient-norm clipping threshold. |
| device | string | "cpu" or a CUDA device. Falls back to CPU if CUDA is unavailable. |
| random_state | int or None | Seeds fitting and sampling. |
| verbose | bool | Prints the training loss periodically. |
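The table constrains beta_start and beta_end but does not say how the intermediate betas are spaced or which timesteps the sample_steps reverse steps visit. The sketch below assumes a linear schedule (the common DDPM choice) and evenly spaced reverse timesteps (as in DDIM-style samplers). linear_beta_schedule and reverse_timesteps are hypothetical helpers, not TabDM functions.

```python
def linear_beta_schedule(timesteps=96, beta_start=1e-4, beta_end=0.02):
    """Return betas and cumulative alpha products for a linear schedule.

    Assumption: linear spacing is illustrative; only the endpoint
    constraints come from the TabDM reference.
    """
    if timesteps <= 1:
        raise ValueError("timesteps must be greater than 1")
    if not 0 < beta_start < beta_end < 1:
        raise ValueError("need 0 < beta_start < beta_end < 1")
    step = (beta_end - beta_start) / (timesteps - 1)
    betas = [beta_start + i * step for i in range(timesteps)]
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta  # alpha_bar_t = prod_{s<=t} (1 - beta_s)
        alpha_bars.append(running)
    return betas, alpha_bars


def reverse_timesteps(timesteps=96, sample_steps=24):
    """Evenly spaced timesteps from timesteps - 1 down to 0.

    Assumption: TabDM does not document which steps its sampler visits.
    """
    if sample_steps <= 0:
        raise ValueError("sample_steps must be positive")
    count = min(sample_steps, timesteps)
    if count == 1:
        return [timesteps - 1]
    last = timesteps - 1
    return [round(last - i * last / (count - 1)) for i in range(count)]
```

With the defaults, reverse_timesteps() visits 24 of the 96 training steps, starting at 95 and ending at 0, which is why sample_steps mainly trades sampling speed against fidelity.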

Column Metadata

column_metadata is a mapping from column name to metadata.

column_metadata = {
    "grade_band": {"type": "ordinal", "order": ["low", "mid", "high"]},
    "incident_count": {"type": "count"},
    "tuition_paid": {"type": "positive_continuous"},
}

Supported metadata:

| Metadata | Required fields | Behavior |
|---|---|---|
| {"type": "ordinal"} | none (order is optional) | Encodes ordered categories as one scalar. If order is omitted, values are ordered numerically when possible, otherwise lexicographically. |
| {"type": "count"} | none | Uses log1p; clips to non-negative values; rounds on inverse transform. |
| {"type": "positive_continuous"} | none | Uses log1p; clips to non-negative values. |
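The three metadata behaviors can be sketched in plain Python. This is an illustration under stated assumptions, not TabDM's transformer: the ordinal scalar is assumed to be the category's rank index, decoding is assumed to round to the nearest rank, and clipping is assumed to happen after the inverse log1p. All function names are hypothetical.

```python
import math


def ordinal_order(values, order=None):
    """Category order: explicit order, else numeric order, else lexicographic."""
    if order is not None:
        return list(order)
    distinct = list(dict.fromkeys(values))
    try:
        return sorted(distinct, key=float)
    except (TypeError, ValueError):
        return sorted(distinct, key=str)


def encode_ordinal(values, order):
    """Assumption: each category is encoded as its rank index."""
    rank = {category: i for i, category in enumerate(order)}
    return [float(rank[v]) for v in values]


def decode_ordinal(codes, order):
    """Round each scalar to the nearest valid rank and look up its category."""
    top = len(order) - 1
    return [order[min(max(round(c), 0), top)] for c in codes]


def forward_log1p(values):
    """Forward transform for "count" and "positive_continuous" columns."""
    return [math.log1p(v) for v in values]


def inverse_log1p(codes, is_count):
    """Invert log1p, clip to non-negative, and round when the column is a count."""
    restored = [max(math.expm1(c), 0.0) for c in codes]
    return [int(round(v)) for v in restored] if is_count else restored
```

Without an explicit order, labels such as "low", "mid", and "high" fall back to lexicographic order ("high", "low", "mid"), which is why the grade_band example above supplies order.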

Transformer API

DataTransformer

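import pandas as pd
from tabdm import DataTransformer

# Small illustrative table: "job" is discrete, "incident_count" is a count.
real = pd.DataFrame({"job": ["admin", "tech", "admin"], "incident_count": [0, 2, 5]})
metadata = {"incident_count": {"type": "count"}}

transformer = DataTransformer()
transformer.fit(real, discrete_columns=["job"], column_metadata=metadata)
matrix = transformer.transform(real)              # transformed feature matrix
restored = transformer.inverse_transform(matrix)  # decoded back to the original columns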

Main attributes after fitting:

| Attribute | Meaning |
|---|---|
| output_dimensions | Number of transformed columns. |
| output_info_list | Span metadata used for loss weighting and decoding. |
| discrete_column_infos | Discrete column metadata for category lookup. |
| column_transform_infos | Per-column transform metadata. |

infer_discrete_columns

infer_discrete_columns(dataframe)

Returns columns whose dtype is object, string, categorical, or boolean.

Evaluation API

evaluate_synthetic

evaluate_synthetic(
    real,
    synthetic,
    target_column=None,
    *,
    test_size=0.25,
    random_state=None,
    task_type="auto",
    include_trust=True,
    metrics="all",
)

Returns a nested dictionary with selected metric groups.

Top-level keys:

| Key | Meaning |
|---|---|
| target_column | Target column used for utility, or None. |
| task_type | Resolved task type. |
| metric_groups | Metric groups included in the report. |
| rows | Row counts for the real and synthetic data, plus the real-data train/evaluation split when one is used. |
| schema | Present when requested. |
| distribution | Present when requested. |
| validity | Present when requested. |
| utility | Present when requested and target_column is set. |
| trust | Present when requested and include_trust=True. |

Metric selection:

evaluate_synthetic(real, synthetic, metrics="all")
evaluate_synthetic(real, synthetic, metrics=("schema", "distribution"))
evaluate_synthetic(real, synthetic, include_trust=False)

Available metric groups:

  • schema
  • distribution
  • validity
  • trust
  • utility
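The table's presence rules imply how the requested groups and the report's metric_groups relate. A sketch of that selection logic follows; resolve_metric_groups is hypothetical, and accepting a single group name as a plain string is an assumption, not documented behavior.

```python
METRIC_GROUPS = ("schema", "distribution", "validity", "trust", "utility")


def resolve_metric_groups(metrics="all", include_trust=True, target_column=None):
    """Return the metric groups a report would contain (hypothetical helper).

    Documented rules: utility needs target_column, trust needs include_trust=True.
    Assumption: a single group name may be passed as a plain string.
    """
    if metrics == "all":
        requested = METRIC_GROUPS
    elif isinstance(metrics, str):
        requested = (metrics,)
    else:
        requested = tuple(metrics)
    unknown = set(requested) - set(METRIC_GROUPS)
    if unknown:
        raise ValueError(f"unknown metric groups: {sorted(unknown)}")
    groups = []
    for name in requested:
        if name == "utility" and target_column is None:
            continue  # utility needs a target column
        if name == "trust" and not include_trust:
            continue  # trust was switched off
        if name not in groups:
            groups.append(name)
    return groups
```

Under these rules, metrics="all" without a target_column yields every group except utility.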

schema_report

schema_report(real, synthetic)

Returns:

  • real_columns
  • synthetic_columns
  • common_columns
  • missing_columns
  • extra_columns
  • dtype_mismatches
  • is_column_compatible

distribution_report

distribution_report(real, synthetic)

Returns categorical, numeric, and correlation summaries:

  • categorical mean/worst total variation distance
  • numeric mean/worst Kolmogorov-Smirnov distance
  • mean absolute numeric correlation delta

For all three summaries, lower values mean the synthetic data matches the real marginals or correlations more closely.
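The two per-column distances follow their standard definitions, sketched below for a single column. categorical_tvd_sketch and numeric_ks_sketch are hypothetical names; TabDM's categorical_tvd and numeric_ks may handle missing values or dtypes differently. The report's "mean/worst" figures aggregate such per-column values across columns.

```python
from bisect import bisect_right
from collections import Counter


def categorical_tvd_sketch(real, synthetic):
    """Total variation distance: half the L1 gap between category frequencies."""
    p, q = Counter(real), Counter(synthetic)
    n_p, n_q = len(real), len(synthetic)
    return 0.5 * sum(abs(p[c] / n_p - q[c] / n_q) for c in set(p) | set(q))


def numeric_ks_sketch(real, synthetic):
    """Two-sample Kolmogorov-Smirnov statistic: largest gap between empirical CDFs."""
    a, b = sorted(real), sorted(synthetic)
    gap = 0.0
    for x in a + b:
        # Fraction of each sample that is <= x, i.e. the empirical CDF at x.
        cdf_a = bisect_right(a, x) / len(a)
        cdf_b = bisect_right(b, x) / len(b)
        gap = max(gap, abs(cdf_a - cdf_b))
    return gap
```

Both distances range from 0 (identical distributions) to 1 (disjoint support).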

validity_report

validity_report(real, synthetic)

Returns:

  • bounds: numeric values outside real-data min/max bounds
  • unseen_categories: synthetic categorical values not observed in real data
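Both validity checks reduce to simple per-column computations. The sketch below assumes the bounds check is reported as a fraction of synthetic values; validity_report may instead report counts or a different structure. Both helper names are hypothetical.

```python
def bounds_violation_rate(real_values, synthetic_values):
    """Fraction of synthetic values outside the real column's [min, max] range."""
    low, high = min(real_values), max(real_values)
    outside = sum(1 for v in synthetic_values if v < low or v > high)
    return outside / len(synthetic_values)


def unseen_categories(real_values, synthetic_values):
    """Synthetic categories that never appear in the real column, sorted."""
    return sorted(set(synthetic_values) - set(real_values))
```

Because TabDM bounds numeric output to the training range and decodes to known categories, non-zero results on its own output usually point to post-processing or to a comparison against a different real table.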

evaluate_utility

evaluate_utility(
    train_df,
    eval_df,
    synthetic_df,
    target_column,
    task_type,
)

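Runs train-on-synthetic, test-on-real (TSTR) evaluation: a model trained on synthetic_df and a baseline trained on train_df are both scored on eval_df.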

For classification, returns:

  • baseline_real_train_classes
  • synthetic_train_classes
  • baseline_real_accuracy
  • synthetic_accuracy
  • baseline_real_macro_f1
  • synthetic_macro_f1

For regression, returns:

  • baseline_real_r2
  • synthetic_r2
  • baseline_real_mae
  • synthetic_mae
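The classification side of TSTR can be sketched end to end with the standard library. This is an illustration only: a 1-nearest-neighbor classifier stands in for whatever model evaluate_utility actually trains, rows are assumed to be numeric feature tuples, and "_train_classes" is assumed to mean the number of distinct training labels. All function names are hypothetical.

```python
import math


def nearest_neighbor_predict(train_rows, train_labels, query_rows):
    """1-nearest-neighbor classifier on numeric feature tuples (stand-in model)."""
    predictions = []
    for q in query_rows:
        best = min(range(len(train_rows)), key=lambda i: math.dist(train_rows[i], q))
        predictions.append(train_labels[best])
    return predictions


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over classes seen in y_true or y_pred."""
    scores = []
    for c in set(y_true) | set(y_pred):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


def train_synthetic_test_real(real_train, synthetic, real_eval):
    """Each argument is (rows, labels); both models are scored on real_eval."""
    eval_rows, eval_labels = real_eval
    report = {}
    for prefix, (rows, labels) in (("baseline_real", real_train), ("synthetic", synthetic)):
        pred = nearest_neighbor_predict(rows, labels, eval_rows)
        report[f"{prefix}_train_classes"] = len(set(labels))
        report[f"{prefix}_accuracy"] = accuracy(eval_labels, pred)
        report[f"{prefix}_macro_f1"] = macro_f1(eval_labels, pred)
    return report
```

The useful signal is the gap between the baseline_real_* and synthetic_* scores; a small gap suggests the synthetic data preserves the feature-target relationship the downstream model relies on.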

trust_report

trust_report(train_df, eval_df, synthetic_df)

Returns privacy-screening diagnostics:

  • exact synthetic row matches against training rows
  • nearest-neighbor distances to training rows
  • risk level derived from exact matches and distance signals
  • disclaimer that these metrics do not prove anonymization or compliance
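The first two signals can be sketched as follows, assuming rows are already numeric tuples (categorical columns would need encoding first). The distance ratio is a common way to use an evaluation split in such diagnostics; how trust_report actually turns these signals into a risk level is not documented here, and every function name below is hypothetical.

```python
import math
import statistics


def exact_match_rate_sketch(train_rows, synthetic_rows):
    """Share of synthetic rows that exactly equal some training row."""
    train_set = {tuple(row) for row in train_rows}
    return sum(tuple(row) in train_set for row in synthetic_rows) / len(synthetic_rows)


def nearest_distances(rows, reference):
    """Euclidean distance from each row to its closest row in reference."""
    return [min(math.dist(row, ref) for ref in reference) for row in rows]


def distance_ratio(train_rows, eval_rows, synthetic_rows):
    """Median synthetic-to-train distance over median eval-to-train distance.

    Values well below 1 mean synthetic rows sit closer to training rows than
    unseen real rows do, a possible sign of memorization.
    """
    synthetic_median = statistics.median(nearest_distances(synthetic_rows, train_rows))
    eval_median = statistics.median(nearest_distances(eval_rows, train_rows))
    return synthetic_median / eval_median if eval_median > 0 else math.inf
```

As the disclaimer notes, a ratio near 1 and zero exact matches are screening results only, not evidence of anonymization.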

Standalone Metrics

categorical_tvd(real_column, synthetic_column)
numeric_ks(real_column, synthetic_column)
numeric_correlation_delta(real_dataframe, synthetic_dataframe)
exact_match_rate(real_dataframe, synthetic_dataframe)
nearest_neighbor_privacy(train_df, eval_df, synthetic_df)
infer_task_type(frame, target_column)

infer_task_type returns:

  • classification for object, string, categorical, boolean, and low-cardinality integer targets
  • regression for floating numeric targets
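The rule can be sketched on plain Python values rather than pandas dtypes. The cardinality threshold (max_integer_classes) is a hypothetical placeholder, and treating integer targets above it as regression is an assumption; the reference only covers low-cardinality integers and floating targets.

```python
def infer_task_type_sketch(values, max_integer_classes=20):
    """Return "classification" or "regression" for a list of target values.

    Assumptions: max_integer_classes is a placeholder threshold, and
    high-cardinality integer targets are treated as regression.
    """
    non_missing = [v for v in values if v is not None]
    # str and bool values behave like categories (bool is checked before int).
    if all(isinstance(v, (str, bool)) for v in non_missing):
        return "classification"
    if all(isinstance(v, int) for v in non_missing):
        few_classes = len(set(non_missing)) <= max_integer_classes
        return "classification" if few_classes else "regression"
    return "regression"
```

An integer label column with values 0, 1, 2 is classified, while a float column such as tuition_paid is regressed.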

Reproducibility Notes

  • random_state passed at fitting time seeds Python, NumPy, and Torch.
  • random_state on sample or generate seeds sampling noise.
  • Results can still vary across hardware, Torch versions, or CUDA kernels.
  • For empirical claims, run multiple seeds and record the exact model and data configuration used to produce the synthetic data.
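Reporting a metric across several seeds can be as simple as the sketch below. summarize_over_seeds is a hypothetical helper, and run_once stands for any callable you write that, for example, fits TabDM with random_state=seed and returns one number from evaluate_synthetic.

```python
import statistics


def summarize_over_seeds(run_once, seeds):
    """Call run_once(seed) for each seed and report the mean and spread.

    run_once is a hypothetical user-supplied callable returning a float.
    """
    scores = [float(run_once(seed)) for seed in seeds]
    return {
        "seeds": list(seeds),
        "scores": scores,
        "mean": statistics.mean(scores),
        # Sample standard deviation needs at least two runs.
        "stdev": statistics.stdev(scores) if len(scores) > 1 else 0.0,
    }
```

Store the returned dictionary alongside the model and data configuration so the reported mean and spread can be traced back to specific runs.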

Validation and Safety Notes

TabDM bounds generated numeric values to the training range and decodes categorical columns to known categories. These constraints reduce invalid outputs, but they are not privacy guarantees.

Before publishing or sharing generated data, run:

  1. schema and validity checks
  2. distribution fidelity checks
  3. task utility checks, if there is a downstream target
  4. privacy-screening checks, including exact matches and nearest-neighbor diagnostics
  5. domain-specific validation for impossible or regulated values