Central Global Model Accuracy
Use Case: Overcoming Data Privacy Constraints to Enhance Centralized ML Operations
Scenario
As the Data Science Lead of the Centralized ML Operations team, your primary responsibility is building and optimizing global machine learning models to detect anomalies effectively. A critical objective is improving model accuracy across both global and region-specific datasets.
Problem
Data is distributed across multiple geographic locations, often governed by stringent privacy and data-sharing regulations. Some regions are unable to share their datasets with the centralized ML team due to these restrictions. This missing data includes rare and critical anomalies essential for training a robust global model.
Without access to this sensitive local data, the global model lacks the necessary insights to detect anomalies effectively, leading to reduced accuracy and limited coverage of rare patterns.
Rockfish Solution
The Rockfish Generative Data Platform addresses this challenge by enabling privacy-preserving data processing and synthetic data generation directly at the restricted locations.
- Local Deployment for Privacy Compliance: Rockfish can be deployed locally (on-site or on private cloud infrastructures) in regions with strict data-sharing regulations. By processing sensitive data in-place, it ensures compliance with privacy laws.
- Privacy-Preserving Synthetic Data: Rockfish uses advanced generative AI models to create high-quality synthetic datasets that statistically mimic the original data. These synthetic datasets capture critical patterns and rare anomalies without exposing sensitive information.
- Integration for Centralized Modeling: The synthetic datasets generated in restricted locations are securely shared with the centralized ML team. By incorporating this synthetic data (a minimal sketch of this pooling step follows below), the global model benefits from enhanced insights, improving anomaly detection accuracy across all regions.
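To make the last point concrete, the central team might pool the per-region synthetic exports into one global training set. The sketch below is illustrative only and assumes each region ships its synthetic data as a CSV file; the directory layout and file names are hypothetical.

import glob
import pandas as pd

# Hypothetical layout: one synthetic CSV per region, e.g. synthetic/region_eu.csv
region_files = sorted(glob.glob("synthetic/region_*.csv"))

# Pool all regions into one global training frame, tagging each row with its source file
global_train = pd.concat(
    [pd.read_csv(f).assign(source=f) for f in region_files],
    ignore_index=True,
)
print(f"Global training set: {len(global_train)} rows from {len(region_files)} regions")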
Dataset used in this tutorial
In this tutorial, we’ll use telemetry data from the Soil Moisture Active Passive (SMAP) satellite. The dataset for this experiment has been pre-encoded and normalized, with Feature_9 as the key feature of interest.
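Before onboarding, it helps to take a quick look at the data. A minimal sketch, assuming a local copy of the pre-processed SMAP data at a hypothetical path:

import pandas as pd

df = pd.read_csv("datafiles/train.csv")  # hypothetical path to your copy of the SMAP data
print(df.shape)
print(df["Feature_9"].describe())  # the key feature of interest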
To run the demo, you can do one of the following:
- Use the Rockfish CLI. Please refer to the Rockfish CLI Installation Guide to set up the Rockfish CLI.
- Use the Rockfish SDK and follow along with the code snippets on this page. Please refer to the Rockfish SDK Installation Guide to install the Rockfish SDK.
Rockfish CLI Demo
Initialize the Demo: Set up the necessary files, then run the onboarding, training, and generation steps end to end:
rockfish-tutorial init central_network
rockfish-tutorial exec central_network run_small onboard
rockfish-tutorial exec central_network run_small train
rockfish-tutorial exec central_network run_small generate
Step-by-Step Guide
1: Onboarding Sample Dataset from Restricted Locations
Onboard a sample dataset to ensure compliance with privacy regulations while maintaining data fidelity.
Load Dataset and Generate Workflow
The Rockfish recommendation engine suggests a customized workflow based on your dataset:
import rockfish as rf
import rockfish.actions as ra
from rockfish.labs.dataset_properties import DatasetPropertyExtractor, DatasetType
from rockfish.labs.steps import Recommender, ModelSelection  # adjust import paths to your SDK version

# filepath: path to the local CSV at the restricted site
dataset = rf.Dataset.from_csv(name="sample_data", path=filepath)

# extract dataset properties and get a recommended workflow for tabular data
dataset_properties = DatasetPropertyExtractor(
    dataset=dataset,
    dataset_type=DatasetType.TABULAR,
).extract()
recommender_output = Recommender(
    dataset_properties=dataset_properties,
    steps=[ModelSelection()],
).run()

# apply user-defined overrides to the recommended config (see the example below)
config = recommender_output.actions[0].config()["tabular-gan"]
for k, v in model_customizations.items():
    config[k] = v
    print(f"{k}: {v}")
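The model_customizations dictionary above holds your overrides to the recommended configuration. For illustration only, it might look like the sketch below; the key names are hypothetical placeholders, so consult the Rockfish SDK reference for the actual tabular-gan options.

# Hypothetical overrides -- the keys are placeholders, not documented tabular-gan options
model_customizations = {
    "epochs": 100,
    "batch_size": 512,
}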
Train Models Locally
Use the recommended configuration to train a generative model locally on the restricted data.
# conn: an open Rockfish API connection (see the SDK Installation Guide)
builder = rf.WorkflowBuilder()
builder.add_path(dataset, *recommender_output.actions, ra.DatasetSave(name="onboarding-fidelity-eval"))
workflow = await builder.start(conn)
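While the workflow runs, you can stream its logs to follow training progress (the same pattern is used in the generate step later in this tutorial):

# stream logs from the running training workflow
async for log in workflow.logs():
    print(log)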
Evaluate sample data quality
Ensure the synthetic data maintains high fidelity by evaluating its statistical similarity to the original data.
import matplotlib.pyplot as plt
import rockfish.labs as rl

# synthetic_data: the generated sample saved by the onboarding workflow
sns = rl.vis.plot_kde([dataset, synthetic_data], field="feature", palette=["g", "b"])
sns.set_xlabels("Normalized Feature")
plt.show()

fidelity_score = rl.metrics.marginal_dist_score(dataset, synthetic_data)
print(f"Fidelity score: {fidelity_score}")
To run this in the terminal:
rockfish-tutorial exec central_network <playthrough-option=view_results | run_small | run_entire> onboard <--logging True | False> <--save-path <path>>
2: Continuous Data Flow for Ongoing Improvements
Once you’ve finalized the workflow, you can set up the Rockfish runtime for continuous training with new data streams:
import pickle

# load runtime_conf (obtained after onboarding)
with open("runtime_conf", "rb") as f:
    runtime_conf = pickle.load(f)

# set up a datastream for the workflow
datastream = runtime_conf.actions["datastream-load"]

# start the runtime workflow
runtime_workflow = await runtime_conf.start(conn)

# dataset_paths: list of CSV files arriving from the restricted location
# stream each new dataset into the running workflow
for i, path in enumerate(dataset_paths):
    dataset = rf.Dataset.from_csv("train", path)
    await runtime_workflow.write_datastream(datastream, dataset)
    print(f"Training model {i} on {path}")

# add labels so each trained model can be queried later
for i, path in enumerate(dataset_paths):
    model = await runtime_workflow.models().nth(i)
    label = path[10:-4]  # derive a label from the file path
    await model.add_labels(conn, kind=f"model_{label}")
    print(f"Finished training model {i} on {path}")
To run in the terminal:
rockfish-tutorial exec central_network <playthrough-option=view_results | run_small | run_entire> train <--logging True | False> <--save-path <path>>
You can integrate with Databricks or S3 for cloud-based data syncing and synthetic data generation, as sketched below. Refer to the Rockfish Integration Guide for details.
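As one illustration of such an integration, new training batches could be staged from an S3 bucket before being written to the datastream. This is a minimal sketch, assuming the s3fs package is installed so pandas can read s3:// URLs; the bucket and key are hypothetical, and the official Rockfish Integration Guide is the authoritative reference.

import pandas as pd

# pandas reads s3:// URLs when s3fs is installed; bucket and key are placeholders
batch = pd.read_csv("s3://restricted-region-bucket/incoming/batch_001.csv")
batch.to_csv("datafiles/batch_001.csv", index=False)  # stage locally for rf.Dataset.from_csv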
3: Generate Synthetic Data for Restricted Locations
After training models on restricted data, use them to generate synthetic data.
# query the model trained earlier, using the label added in the previous step;
# model_label should match the `kind` label (e.g. "model_<label>")
model = await conn.list_models(
    labels={"kind": model_label}
).last()
Generate the synthetic dataset using the queried model:
# load the generation config (obtained after onboarding)
with open("generate_conf", "rb") as f:
    generate_conf = pickle.load(f)

# build and start a generation workflow: model -> generate -> save
builder = rf.WorkflowBuilder()
builder.add_path(model, generate_conf, ra.DatasetSave(name="synthetic"))
workflow = await builder.start(conn)
async for log in workflow.logs():
    print(log)

# download the generated synthetic dataset
syn_dataset = await (await workflow.datasets().concat(conn)).to_local(conn)
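The Prophet example in the next section works on a pandas DataFrame. Assuming the local dataset exposes a to_pandas() conversion, as in the Rockfish SDK examples, the synthetic data can be handed off like this:

# convert the synthetic dataset into the `data` DataFrame used in the next section
data = syn_dataset.to_pandas()
print(data.head())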
Enhancing the Global Model with Synthetic Data
Combine synthetic data from all locations to improve the accuracy of the global model. For this tutorial, we use Facebook’s Prophet forecasting model for anomaly detection: points that fall outside the model’s prediction interval are flagged as anomalies.
import numpy as np
import pandas as pd
import prophet

# data: the combined synthetic data as a pandas DataFrame (see the conversion above)
# reformat the data to match Prophet's expected schema (ds = timestamp, y = value)
data = data[["feature", "timestamp"]].rename(columns={"timestamp": "ds", "feature": "y"})
data["ds"] = pd.to_datetime(data["ds"], format="%Y-%m-%d %H:%M:%S", exact=False)
data = data.dropna()

# read the test data and its ground-truth anomaly labels
test = pd.read_csv("datafiles/test.csv")
test_labels = pd.read_csv("datafiles/test_label.csv")

# train the model
model = prophet.Prophet(daily_seasonality=True)
np.random.seed(500)
model.fit(data)

# timestamps: the test-period timestamps to forecast over
future = pd.DataFrame()
future["ds"] = timestamps

# make predictions using the learnt model
forecast = model.predict(future)
Visualization:
Highlight true and false positives to evaluate model performance:
# get anomaly labels: flag a point as anomalous when it exceeds the forecast's upper bound
# feature: name of the feature column in the test set (e.g., "Feature_9")
pred_labels = np.where(test[feature] <= forecast["yhat_upper"], 0, 1)
# plot
fig, ax = plt.subplots()
x = pd.to_datetime(forecast["ds"]) # plot timestamps on x axis
ax.plot(x, test[feature], "g", label="True Value")
ax.plot(x, forecast["yhat"], "b", label="Predicted Value")
ax.fill_between(x, forecast["yhat_lower"], forecast["yhat_upper"], alpha=0.1)
ax.set_ylim(0.4, 1.03)
# mark true and false positives
tp_idxs = np.where((test_labels["label"] == 1) & (pred_labels == 1))[0] # get idxs for true positives
fp_idxs = np.where((test_labels["label"] == 0) & (pred_labels == 1))[0] # get idxs for false positives
ax.plot(x.iloc[tp_idxs], test[feature].iloc[tp_idxs], "r.", label="True Anomaly")
ax.plot(x.iloc[fp_idxs], test[feature].iloc[fp_idxs], "k.", label="False Anomaly")
ax.legend()
plt.xlabel("Time")
plt.ylabel("Feature")
plt.show()
Evaluate the model with a confusion matrix: a table that summarizes a classifier’s performance by comparing the true labels with the model’s predictions.
print(f"F1 Score: {f1_score(y_true=test_labels['label'], y_pred=pred_labels):.2f}")
tn, fp, fn, tp = confusion_matrix(y_true=test_labels["label"], y_pred=pred_labels).ravel()
print(f"TP: {tp}, FP: {fp}")
print(f"True Positive Rate: {((tp / (tp + fn)) * 100):.2f}%")
print(f"False Positive Rate: {((fp / (tn + fp)) * 100):.2f}%")
print(f'Accuracy: {((tp + tn) / (tp + tn + fp + fn) * 100):.2f}%')
To run in the terminal:
rockfish-tutorial exec central_network <playthrough-option=view_results | run_small | run_entire> generate <--logging True | False> <--save-path <path>>
Results
- Improved Global Model Accuracy: Incorporating synthetic data enables better anomaly detection globally and within each region.
- Privacy Compliance: Rockfish ensures sensitive data never leaves restricted locations, maintaining compliance with privacy regulations.
- Scalability: Continuous data flow and model updates support dynamic environments and evolving patterns.
Next Steps
- Experiment with Diverse Datasets: Apply this workflow to datasets from different industries to test Rockfish’s adaptability.
- Optimize Hyperparameters: Leverage Rockfish’s tuning capabilities to refine synthetic data quality and model performance.