Train
The Rockfish SDK currently supports four models. Please refer to data sources and models to determine which kind of dataset you have, which model suits it, and what each model's hyperparameters mean.
For training, each model requires its own configuration. The configuration consists of two parts:
- The encoder config specifies the encoding schema of the training data by giving each field's name and type. The common types are "categorical", "continuous" and "ignore"; the timestamp field of a time series dataset is configured separately. A field of type "ignore" is not considered in training. A field with no assigned type is automatically treated as "continuous".
- The model config mainly specifies the model hyperparameters.
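The defaulting rule for field types can be sketched in plain Python. This is a minimal illustration, not SDK code: the helpers `resolve_field_types` and `trained_fields` and the column names are hypothetical, and the sketch only mirrors the behavior described above for the three common types (an untyped field becomes "continuous", an "ignore" field is left out of training).

```python
from typing import Dict, List, Optional

# The three common types named in this section.
VALID_TYPES = {"categorical", "continuous", "ignore"}


def resolve_field_types(
    columns: List[str], assigned: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
    """Return the effective type of every column (hypothetical helper, not SDK code)."""
    assigned = assigned or {}
    unknown = set(assigned) - set(columns)
    if unknown:
        raise ValueError(f"types assigned to unknown fields: {sorted(unknown)}")
    resolved = {}
    for name in columns:
        # Fields without an explicit type fall back to "continuous".
        field_type = assigned.get(name, "continuous")
        if field_type not in VALID_TYPES:
            raise ValueError(f"unsupported type {field_type!r} for field {name!r}")
        resolved[name] = field_type
    return resolved


def trained_fields(resolved: Dict[str, str]) -> List[str]:
    """Fields that take part in training, i.e. everything not marked "ignore"."""
    return [name for name, field_type in resolved.items() if field_type != "ignore"]


effective = resolve_field_types(["A", "B", "C", "D"], {"A": "categorical", "C": "ignore"})
print(effective)
print(trained_fields(effective))
```

Listing the effective types this way before writing the encoder config makes it easier to spot a categorical field that would otherwise silently default to "continuous".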
Introduction to the configuration
Before building any configuration, make sure to import the actions module with import rockfish.actions as ra
Model RF-Tab-GAN
Use this model if your dataset is tabular.
Example fields in a tabular dataset: "A" is a categorical field, "B" is a continuous field, and "C" is ignored.
encoder_config = ra.TrainTabGAN.DatasetConfig(
    metadata=[
        ra.TrainTabGAN.FieldConfig(field="A", type="categorical"),
        ra.TrainTabGAN.FieldConfig(field="B", type="continuous"),
        ra.TrainTabGAN.FieldConfig(field="C", type="ignore"),
    ]
)
The model config can be updated based on the dataset and use cases. For details on other model hyperparameters, please refer to the information in models.
model_config = ra.TrainTabGAN.TrainConfig(epochs=100)
Combine the encoder config and the model config into the final configuration:
train_config = ra.TrainTabGAN.Config(encoder=encoder_config, tabular_gan=model_config)
After obtaining the configuration (here named train_config), pass it to the train action of the selected model.
train = ra.TrainTabGAN(train_config)
Model RF-Tab-Transformer
Use this model if your dataset is tabular.
Example fields in a tabular dataset: "A" is a categorical field, "B" is a continuous field, and "C" is ignored.
encoder_config = ra.TrainTabTransformer.DatasetConfig(
    metadata=[
        ra.TrainTabTransformer.FieldConfig(field="A", type="categorical"),
        ra.TrainTabTransformer.FieldConfig(field="B", type="continuous"),
        ra.TrainTabTransformer.FieldConfig(field="C", type="ignore"),
    ]
)
The model config can be updated based on the dataset and use cases. For details on model hyperparameters, please refer to the information in models.
model_config = ra.TrainTabTransformer.TrainConfig(
    num_bootstrap=100,
    epochs=100,
    gpt2_config=ra.TrainTimeTransformer.GPT2Config(
        layer=12, head=12, embed=768
    ),
)
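When changing the GPT-2 hyperparameters, keep in mind that standard multi-head attention splits the embedding evenly across heads, so embed is normally expected to be divisible by head. A small sanity check could look like the sketch below; `check_gpt2_dims` is a hypothetical helper, not part of the SDK, and whether the SDK performs this validation itself is not stated here.

```python
def check_gpt2_dims(layer: int, head: int, embed: int) -> int:
    """Return the per-head dimension, or raise if the sizes are inconsistent."""
    if layer <= 0 or head <= 0 or embed <= 0:
        raise ValueError("layer, head and embed must be positive integers")
    if embed % head != 0:
        raise ValueError(f"embed={embed} is not divisible by head={head}")
    return embed // head


# Values taken from the example configuration above.
head_dim = check_gpt2_dims(layer=12, head=12, embed=768)
print(head_dim)  # 64
```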
Combine the encoder config and the model config into the final configuration:
train_config = ra.TrainTabTransformer.Config(encoder=encoder_config, rtf=model_config)
After obtaining the configuration (here named train_config), pass it to the train action of the selected model.
train = ra.TrainTabTransformer(train_config)
Model RF-Time-GAN
Use this model if your dataset is time series.
Example fields in a time series dataset: "metadata_A" is a categorical field, "metadata_B" is a continuous field, "metadata_C" is ignored, "timestamp" is the timestamp field, "measurement_A" is a categorical field, "measurement_B" is a continuous field, and "measurement_C" is ignored.
encoder_config = ra.TrainTimeGAN.DatasetConfig(
    timestamp=ra.TrainTimeGAN.TimestampConfig(field="timestamp"),
    metadata=[
        ra.TrainTimeGAN.FieldConfig(field="metadata_A", type="categorical"),
        ra.TrainTimeGAN.FieldConfig(field="metadata_B", type="continuous"),
        ra.TrainTimeGAN.FieldConfig(field="metadata_C", type="ignore"),
    ],
    measurements=[
        ra.TrainTimeGAN.FieldConfig(field="measurement_A", type="categorical"),
        ra.TrainTimeGAN.FieldConfig(field="measurement_B", type="continuous"),
        ra.TrainTimeGAN.FieldConfig(field="measurement_C", type="ignore"),
    ],
)
RF-Time-GAN also supports a special type, "session", for a single metadata field that serves as the session_key. It is usually meant for a large categorical field (a field of high cardinality) whose values are not expected to be remembered, learned or trained, such as "user ID" or "series number ID". Using it requires advanced knowledge of both your training data and the model algorithm. If you want to know more, or your data has one large categorical metadata field, please contact us via support@rockfish.ai
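To judge whether a metadata field is a high-cardinality candidate, one option is to count its distinct values before configuring the encoder. The sketch below uses only the standard library; the sample rows, the `field_cardinality` helper and the 0.5 ratio threshold are illustrative assumptions, not Rockfish recommendations.

```python
from typing import Dict, Iterable, List, Mapping


def field_cardinality(
    rows: Iterable[Mapping[str, object]], fields: List[str]
) -> Dict[str, int]:
    """Count distinct values per field (hypothetical helper, not SDK code)."""
    seen = {name: set() for name in fields}
    for row in rows:
        for name in fields:
            seen[name].add(row[name])
    return {name: len(values) for name, values in seen.items()}


# Made-up sample data: every row has its own user_id, regions repeat.
rows = [
    {"user_id": "u1", "region": "east"},
    {"user_id": "u2", "region": "east"},
    {"user_id": "u3", "region": "west"},
    {"user_id": "u4", "region": "west"},
]
counts = field_cardinality(rows, ["user_id", "region"])

# Assumed rule of thumb: flag a field when more than half the rows are distinct.
THRESHOLD = 0.5
candidates = [name for name, n in counts.items() if n / len(rows) > THRESHOLD]
print(counts)
print(candidates)
```

A field flagged this way is only a candidate; whether it should actually be typed as "session" still depends on the data and the model, as noted above.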
The model config can be updated based on the dataset and use cases. For details on model hyperparameters, please refer to the information in models.
model_config = ra.TrainTimeGAN.DGConfig(
    epoch=100,
    batch_size=100,  # must be smaller than the total number of sessions
)
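The batch_size constraint can be checked before a training job is submitted. The sketch below assumes that a session is the group of rows sharing the same values in all metadata fields; that definition, the helpers `count_sessions` and `check_batch_size`, and the sample rows are assumptions for illustration, so confirm how sessions are formed for your dataset.

```python
from typing import Iterable, List, Mapping


def count_sessions(
    rows: Iterable[Mapping[str, object]], metadata_fields: List[str]
) -> int:
    """Count distinct metadata combinations (assumed here to define a session)."""
    return len({tuple(row[name] for name in metadata_fields) for row in rows})


def check_batch_size(batch_size: int, num_sessions: int) -> None:
    """Raise if batch_size is not smaller than the number of sessions."""
    if batch_size >= num_sessions:
        raise ValueError(
            f"batch_size={batch_size} must be smaller than "
            f"the number of sessions ({num_sessions})"
        )


# Made-up sample rows: the first two share metadata, so they form one session.
rows = [
    {"metadata_A": "x", "metadata_B": 1.0, "measurement_B": 0.3},
    {"metadata_A": "x", "metadata_B": 1.0, "measurement_B": 0.7},
    {"metadata_A": "y", "metadata_B": 2.0, "measurement_B": 0.1},
    {"metadata_A": "z", "metadata_B": 2.0, "measurement_B": 0.9},
]
num_sessions = count_sessions(rows, ["metadata_A", "metadata_B"])
check_batch_size(batch_size=2, num_sessions=num_sessions)  # passes: 2 < 3
print(num_sessions)
```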
Combine the encoder config and the model config into the final configuration:
train_config = ra.TrainTimeGAN.Config(
    encoder=encoder_config,
    doppelganger=model_config,
)
After obtaining the configuration (here named train_config), pass it to the train action of the selected model.
train = ra.TrainTimeGAN(train_config)
Model RF-Time-Transformer
Use this model if your dataset is time series.
Example fields in a time series dataset: "metadata_A" is a categorical field, "metadata_B" is a continuous field, "metadata_C" is ignored, "timestamp" is the timestamp field, "measurement_A" is a categorical field, "measurement_B" is a continuous field, and "measurement_C" is ignored.
encoder_config = ra.TrainTimeTransformer.DatasetConfig(
    timestamp=ra.TrainTimeTransformer.TimestampConfig(field="timestamp"),
    metadata=[
        ra.TrainTimeTransformer.FieldConfig(field="metadata_A", type="categorical"),
        ra.TrainTimeTransformer.FieldConfig(field="metadata_B", type="continuous"),
        ra.TrainTimeTransformer.FieldConfig(field="metadata_C", type="ignore"),
    ],
    measurements=[
        ra.TrainTimeTransformer.FieldConfig(field="measurement_A", type="categorical"),
        ra.TrainTimeTransformer.FieldConfig(field="measurement_B", type="continuous"),
        ra.TrainTimeTransformer.FieldConfig(field="measurement_C", type="ignore"),
    ],
)
The model config can be updated based on the dataset and use cases. For details on model hyperparameters, please refer to the information in models.
model_config = ra.TrainTimeTransformer.TrainConfig(
    num_bootstrap=100,
    parent=ra.TrainTimeTransformer.ParentConfig(
        epochs=100,
        gpt2_config=ra.TrainTimeTransformer.GPT2Config(
            layer=12, head=12, embed=768
        ),
    ),
    child=ra.TrainTimeTransformer.ChildConfig(
        epochs=100, output_max_length=512
    ),
)
Combine the encoder config and the model config into the final configuration:
train_config = ra.TrainTimeTransformer.Config(encoder=encoder_config, rtf=model_config)
After obtaining the configuration (here named train_config), pass it to the train action of the selected model.
import rockfish.actions as ra
train = ra.TrainTimeTransformer(train_config)
Finding it tedious to set up the encoder and model configuration by hand?
No worries! The Rockfish SDK is currently developing a recommendation engine to streamline manual configuration and the subsequent train action. For details, please refer to the recommendation engine page.
Training Workflow
After creating the train action, build a training workflow to start the training job as follows.
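import rockfish as rf

# Assumes `conn` (a Rockfish connection) and `dataset` already exist.
builder = rf.WorkflowBuilder()
builder.add_dataset(dataset)
builder.add_action(train, parents=[dataset])
workflow = await builder.start(conn)  # top-level await needs a notebook or an async function
print(f"Workflow: {workflow.id()}")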
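The snippet above uses await at the top level, which works in environments that already run an event loop, such as a Jupyter notebook. In a plain Python script, the awaited call has to live inside a coroutine driven by asyncio.run. The sketch below shows only that pattern: `start_workflow_stub` is a hypothetical stand-in for builder.start(conn) and returns a made-up ID so the example runs without the SDK.

```python
import asyncio


async def start_workflow_stub() -> str:
    """Hypothetical stand-in for `await builder.start(conn)`; returns a fake workflow ID."""
    await asyncio.sleep(0)  # yield to the event loop, as a real network call would
    return "workflow-123"


async def main() -> str:
    # In real code, build the workflow here and await builder.start(conn).
    workflow_id = await start_workflow_stub()
    print(f"Workflow: {workflow_id}")
    return workflow_id


if __name__ == "__main__":
    asyncio.run(main())
```

Keeping the builder calls inside main() lets the same code run from a script, while a notebook can simply await main() directly.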