Train

The Rockfish SDK currently supports four models. Refer to the data sources and models pages to determine which kind of dataset you have, which model suits it, and what each model's hyperparameters mean.

Each model requires its own training configuration, which consists of two parts:

  • The encoder config describes the encoding schema of the training data by specifying each field's name and type.

    The common types are "categorical", "continuous" and "ignore"; the timestamp field of a time series dataset is configured separately. A field of type "ignore" is not considered in training. A field with no assigned type is automatically treated as "continuous".

  • The model config is mainly used to set the model hyperparameters.
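
The type rules above can be sketched in plain Python. This is only an illustration of the defaulting behaviour described here, not part of the Rockfish SDK: the helper names `resolve_field_types` and `trainable_fields` are hypothetical, and the timestamp field is assumed to be handled elsewhere.

```python
# Hypothetical helpers (not Rockfish SDK functions) that mirror the rules above.
VALID_TYPES = {"categorical", "continuous", "ignore"}


def resolve_field_types(columns, declared):
    """Return {field: type} for every column, defaulting to "continuous"."""
    resolved = {}
    for name in columns:
        # A field with no assigned type is treated as "continuous".
        field_type = declared.get(name, "continuous")
        if field_type not in VALID_TYPES:
            raise ValueError(f"unknown type {field_type!r} for field {name!r}")
        resolved[name] = field_type
    return resolved


def trainable_fields(resolved):
    """Fields that take part in training, i.e. everything not marked "ignore"."""
    return [name for name, field_type in resolved.items() if field_type != "ignore"]


resolved = resolve_field_types(
    ["A", "B", "C", "D"],
    {"A": "categorical", "C": "ignore"},
)
print(resolved)
print(trainable_fields(resolved))
```

Here "B" and "D" have no declared type, so they fall back to "continuous", while "C" is dropped from the list of trainable fields.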

Introduction to the configuration

First, import the actions module under the alias ra, which all examples below use:

import rockfish.actions as ra


Model RF-Tab-GAN

Use this model if your dataset is tabular.

Example fields in a tabular dataset: "A" is a categorical field, "B" is a continuous field, and "C" is ignored.

encoder_config = ra.TrainTabGAN.DatasetConfig(
        metadata=[
            ra.TrainTabGAN.FieldConfig(field="A", type="categorical"),
            ra.TrainTabGAN.FieldConfig(field="B", type="continuous"),
            ra.TrainTabGAN.FieldConfig(field="C", type="ignore")]
)

Adjust the model config to suit your dataset and use case. For details on the other model hyperparameters, refer to the models page.

model_config = ra.TrainTabGAN.TrainConfig(epochs=100)

Combine the encoder config and the model config into the final configuration:

train_config = ra.TrainTabGAN.Config(encoder=encoder_config, tabular_gan=model_config)

Then pass the configuration (here named train_config) to the train action of the selected model:

train = ra.TrainTabGAN(train_config)


Model RF-Tab-Transformer

Use this model if your dataset is tabular.

Example fields in a tabular dataset: "A" is a categorical field, "B" is a continuous field, and "C" is ignored.

encoder_config = ra.TrainTabTransformer.DatasetConfig(
        metadata=[
        ra.TrainTabTransformer.FieldConfig(field="A", type="categorical"),
        ra.TrainTabTransformer.FieldConfig(field="B", type="continuous"),
        ra.TrainTabTransformer.FieldConfig(field="C", type="ignore")]
)

Adjust the model config to suit your dataset and use case. For details on the model hyperparameters, refer to the models page.

model_config=ra.TrainTabTransformer.TrainConfig(
        num_bootstrap=100,
        epochs=100,
        gpt2_config=ra.TrainTimeTransformer.GPT2Config(
            layer=12, head=12, embed=768
        ),
)

Combine the encoder config and the model config into the final configuration:

train_config = ra.TrainTabTransformer.Config(encoder=encoder_config, rtf=model_config)

Then pass the configuration (here named train_config) to the train action of the selected model:

train = ra.TrainTabTransformer(train_config)


Model RF-Time-GAN

Use this model if your dataset is time series.

Example fields in a time series dataset: "metadata_A" is a categorical field, "metadata_B" is a continuous field, and "metadata_C" is ignored; "timestamp" is the timestamp field; "measurement_A" is a categorical field, "measurement_B" is a continuous field, and "measurement_C" is ignored.

encoder_config = ra.TrainTimeGAN.DatasetConfig(
    timestamp=ra.TrainTimeGAN.TimestampConfig(field="timestamp"),
    metadata=[
        ra.TrainTimeGAN.FieldConfig(field="metadata_A", type="categorical"),
        ra.TrainTimeGAN.FieldConfig(field="metadata_B", type="continuous"),
        ra.TrainTimeGAN.FieldConfig(field="metadata_C", type="ignore")
    ],
    measurements=[
        ra.TrainTimeGAN.FieldConfig(field="measurement_A", type="categorical"),
        ra.TrainTimeGAN.FieldConfig(field="measurement_B", type="continuous"),
        ra.TrainTimeGAN.FieldConfig(field="measurement_C", type="ignore"),
    ],
)

Note: RF-Time-GAN supports a special type, "session", for one metadata field that serves as the session_key. It is usually meant for a large categorical field (a field of high cardinality) whose values are not expected to be remembered, learned or trained, such as "user ID" or "series number ID". Using it requires advanced knowledge of your training data and of the model algorithm. If you want to know more, or your data has one large categorical metadata field, please contact us via support@rockfish.ai
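
To judge whether a metadata field is "large" in the sense used here, one rough starting point is to count its distinct values. The sketch below is plain Python and not part of the Rockfish SDK; the helper name `cardinality_report` and the sample fields `user_id` and `region` are made up for illustration, and it deliberately applies no threshold, since deciding on a session key still depends on your data.

```python
# Hypothetical helper (not a Rockfish SDK function): distinct-value counts per field.
def cardinality_report(rows, fields):
    """Map each field to (distinct value count, distinct count / row count)."""
    total = len(rows)
    report = {}
    for field in fields:
        distinct = len({row[field] for row in rows})
        ratio = distinct / total if total else 0.0
        report[field] = (distinct, ratio)
    return report


# Made-up sample rows: user_id is unique per row, region repeats.
rows = [
    {"user_id": "u1", "region": "east"},
    {"user_id": "u2", "region": "east"},
    {"user_id": "u3", "region": "west"},
    {"user_id": "u4", "region": "east"},
]

for field, (distinct, ratio) in cardinality_report(rows, ["user_id", "region"]).items():
    print(f"{field}: {distinct} distinct values, ratio {ratio:.2f}")
```

A field whose ratio is close to 1, like `user_id` in this sample, has nearly one value per row and is the kind of identifier the note above describes.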

Adjust the model config to suit your dataset and use case. For details on the model hyperparameters, refer to the models page.

model_config = ra.TrainTimeGAN.DGConfig(
    epoch=100,
    batch_size=100,  # must be smaller than the total number of sessions
)
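
The batch_size constraint in the comment above can be checked before training. The sketch below is plain Python rather than Rockfish SDK code, and it rests on an assumption: a session is taken to be the group of rows that share the same metadata values. The helper names `count_sessions` and `check_batch_size` are hypothetical.

```python
# Hypothetical helpers (not Rockfish SDK functions).
def count_sessions(rows, metadata_fields):
    """Count distinct metadata value combinations (assumed definition of a session)."""
    return len({tuple(row[field] for field in metadata_fields) for row in rows})


def check_batch_size(batch_size, num_sessions):
    """Return batch_size if it is strictly smaller than num_sessions, else raise."""
    if batch_size >= num_sessions:
        raise ValueError(
            f"batch_size={batch_size} must be smaller than the number of sessions ({num_sessions})"
        )
    return batch_size


# Made-up sample rows: the first two rows share the same metadata values.
rows = [
    {"metadata_A": "x", "metadata_B": 1.0, "measurement_B": 0.1},
    {"metadata_A": "x", "metadata_B": 1.0, "measurement_B": 0.2},
    {"metadata_A": "y", "metadata_B": 2.0, "measurement_B": 0.3},
    {"metadata_A": "z", "metadata_B": 2.0, "measurement_B": 0.4},
]

num_sessions = count_sessions(rows, ["metadata_A", "metadata_B"])
batch_size = check_batch_size(2, num_sessions)
print(num_sessions, batch_size)
```

With this sample there are three sessions, so a batch size of 2 passes, while 3 or more would raise an error.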

Combine the encoder config and the model config into the final configuration:

train_config = ra.TrainTimeGAN.Config(
    encoder = encoder_config,
    doppelganger = model_config,
)

Then pass the configuration (here named train_config) to the train action of the selected model:

train = ra.TrainTimeGAN(train_config)


Model RF-Time-Transformer

Use this model if your dataset is time series.

Example fields in a time series dataset: "metadata_A" is a categorical field, "metadata_B" is a continuous field, and "metadata_C" is ignored; "timestamp" is the timestamp field; "measurement_A" is a categorical field, "measurement_B" is a continuous field, and "measurement_C" is ignored.

encoder_config = ra.TrainTimeTransformer.DatasetConfig(
    timestamp=ra.TrainTimeTransformer.TimestampConfig(field="timestamp"),
    metadata=[
        ra.TrainTimeTransformer.FieldConfig(field="metadata_A", type="categorical"),
        ra.TrainTimeTransformer.FieldConfig(field="metadata_B", type="continuous"),
        ra.TrainTimeTransformer.FieldConfig(field="metadata_C", type="ignore"),
    ],
    measurements=[
        ra.TrainTimeTransformer.FieldConfig(field="measurement_A", type="categorical"),
        ra.TrainTimeTransformer.FieldConfig(field="measurement_B", type="continuous"),
        ra.TrainTimeTransformer.FieldConfig(field="measurement_C", type="ignore"),
    ],
)

Adjust the model config to suit your dataset and use case. For details on the model hyperparameters, refer to the models page.

model_config = ra.TrainTimeTransformer.TrainConfig(
        num_bootstrap=100,
        parent=ra.TrainTimeTransformer.ParentConfig(
            epochs=100,
            gpt2_config=ra.TrainTimeTransformer.GPT2Config(
                layer=12, head=12, embed=768
            ),
        ),
        child=ra.TrainTimeTransformer.ChildConfig(
            epochs=100, output_max_length=512
        ),
)

Combine the encoder config and the model config into the final configuration:

train_config = ra.TrainTimeTransformer.Config(encoder=encoder_config, rtf=model_config)

Then pass the configuration (here named train_config) to the train action of the selected model:

import rockfish.actions as ra
train = ra.TrainTimeTransformer(train_config)


Finding it tedious to set up the encoder and model parts of the configuration?

No worries! The Rockfish SDK is developing a recommendation engine to streamline the manual configuration and the subsequent train action. For details, please refer to the recommendation engine page.


Training Workflow

Once you have the train action, build a training workflow to start the training job as follows.

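import rockfish as rf

# dataset and conn are assumed to be a Rockfish dataset and connection created beforehand
builder = rf.WorkflowBuilder()
builder.add_dataset(dataset)
builder.add_action(train, parents=[dataset])
# top-level await needs an async context, such as a notebook
workflow = await builder.start(conn)
print(f"Workflow: {workflow.id()}")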
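
The workflow snippet uses await at the top level, which works in a notebook but not in a plain Python script. A minimal sketch of one way to run it from a script is to wrap the call in a coroutine and hand it to asyncio.run. The wrapper name `start_training` is hypothetical; it relies only on the `builder.start(conn)` and `workflow.id()` calls shown above and takes the builder and connection as arguments.

```python
import asyncio


async def start_training(builder, conn):
    """Start the workflow held by builder and return its id.

    builder is expected to be the configured workflow builder from the snippet
    above and conn the connection passed to builder.start.
    """
    workflow = await builder.start(conn)
    return workflow.id()


# In a script, assuming builder and conn were created as in the snippet above:
# workflow_id = asyncio.run(start_training(builder, conn))
# print(f"Workflow: {workflow_id}")
```

Keeping the await inside a coroutine means the same function can be awaited directly in a notebook or run with asyncio.run from a script.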