<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.12.2"> </script>
<p>Data space with 4 lines</p>
<div id="dataSpaceWith4Lines" class="plots"></div>
/* Lay out every element carrying the "plots" class as an inline-block box. */
.plots {
display: inline-block;
}
////
// Data
////
// Two parallel-array data sets of 20 samples each: sizeMB[i] pairs with
// timeSec[i]. The plot labels these axes 'size (MB)' and 'time (sec)'.

// Samples the model is fitted on.
const trainData = {
  sizeMB: [
    0.080, 9.000, 0.001, 0.100, 8.000,
    5.000, 0.100, 6.000, 0.050, 0.500,
    0.002, 2.000, 0.005, 10.00, 0.010,
    7.000, 6.000, 5.000, 1.000, 1.000,
  ],
  timeSec: [
    0.135, 0.739, 0.067, 0.126, 0.646,
    0.435, 0.069, 0.497, 0.068, 0.116,
    0.070, 0.289, 0.076, 0.744, 0.083,
    0.560, 0.480, 0.399, 0.153, 0.149,
  ],
};

// Held-out samples; in this script they are only plotted, never trained on.
const testData = {
  sizeMB: [
    5.000, 0.200, 0.001, 9.000, 0.002,
    0.020, 0.008, 4.000, 0.001, 1.000,
    0.005, 0.080, 0.800, 0.200, 0.050,
    7.000, 0.005, 0.002, 8.000, 0.008,
  ],
  timeSec: [
    0.425, 0.098, 0.052, 0.686, 0.066,
    0.078, 0.070, 0.375, 0.058, 0.136,
    0.052, 0.063, 0.183, 0.087, 0.066,
    0.558, 0.066, 0.068, 0.610, 0.057,
  ],
};
// Convert data to Tensors.
// Each flat array becomes a column tensor of shape [numSamples, 1]. The row
// count is read from the array itself rather than hard-coded, so samples can
// be added to or removed from the data sets without tf.tensor2d rejecting a
// stale shape.
const trainTensors = {
  sizeMB: tf.tensor2d(trainData.sizeMB, [trainData.sizeMB.length, 1]),
  timeSec: tf.tensor2d(trainData.timeSec, [trainData.timeSec.length, 1]),
};
const testTensors = {
  sizeMB: tf.tensor2d(testData.sizeMB, [testData.sizeMB.length, 1]),
  timeSec: tf.tensor2d(testData.timeSec, [testData.timeSec.length, 1]),
};
////
// Data Markers
////
// Builds a markers-only scatter trace for one data set. The trace shares the
// data set's arrays (no copy): sizeMB goes on x, timeSec on y.
const makeMarkerTrace = ({sizeMB, timeSec}, name, symbol, size) => ({
  x: sizeMB,
  y: timeSec,
  name,
  mode: 'markers',
  type: 'scatter',
  marker: {symbol, size},
});

// Training samples are drawn as circles, test samples as larger triangles.
const dataTraceTrain = makeMarkerTrace(trainData, 'trainData', 'circle', 8);
const dataTraceTest = makeMarkerTrace(testData, 'testData', 'triangle-up', 10);
// Line traces for the fitted model at four points during training. Each one
// starts out as the same short placeholder segment with a generic legend
// name; only the line styling differs so the four can be told apart.
const makeLineTrace = (color, width, dash) => ({
  x: [0, 2],
  y: [0, 0.01],
  name: 'model after N epochs',
  mode: 'lines',
  line: {color, width, dash},
});

const dataTrace10Epochs = makeLineTrace('blue', 1, 'dot');
const dataTrace20Epochs = makeLineTrace('green', 2, 'dash');
const dataTrace100Epochs = makeLineTrace('red', 3, 'longdash');
const dataTrace200Epochs = makeLineTrace('black', 4, 'solid');
////
// Set up plotly plot.
////
// Trace order fixes each trace's index in the plot: the two marker traces
// take 0 and 1, the four model lines take 2 through 5.
const plotTraces = [
  dataTraceTrain,
  dataTraceTest,
  dataTrace10Epochs,
  dataTrace20Epochs,
  dataTrace100Epochs,
  dataTrace200Epochs,
];

// Figure width, title and axis captions.
const plotLayout = {
  width: 700,
  title: 'Model fit result',
  xaxis: {title: 'size (MB)'},
  yaxis: {title: 'time (sec)'},
};

Plotly.newPlot('dataSpaceWith4Lines', plotTraces, plotLayout);
////
// Construct and compile model.
////
// A sequential model holding a single dense layer: one unit fed by one
// input feature.
const model = tf.sequential();
const denseLayer = tf.layers.dense({inputShape: [1], units: 1});
model.add(denseLayer);

// The SGD learning rate is kept deliberately small so the fit progresses
// slowly enough to illustrate.
const optimizer = tf.train.sgd(0.0005);
model.compile({loss: 'meanAbsoluteError', optimizer});
/**
 * Redraws one model line on the 'dataSpaceWith4Lines' plot as the straight
 * line y = k * x + b, and renames it in the legend.
 *
 * The trace object passed in is also updated in place, so its x/y mirror
 * what is shown on the plot.
 *
 * @param {Object} dataTrace - Trace object whose x and y are overwritten.
 * @param {number} k - Slope of the line.
 * @param {number} b - Intercept of the line (its y value at x = 0).
 * @param {number} N - Epoch count shown in the legend name.
 * @param {number} traceIndex - Index of the trace to restyle on the plot.
 * @param {number} [xMax=10] - Right-hand x value of the drawn segment; the
 *     segment always starts at x = 0.
 */
function updateScatterWithLines(dataTrace, k, b, N, traceIndex, xMax = 10) {
  // Two end points are enough to draw a straight line.
  dataTrace.x = [0, xMax];
  dataTrace.y = [b, b + (k * xMax)];
  // x and y are wrapped in an outer array so the whole coordinate array is
  // passed as the value for the one trace being restyled.
  const update = {
    x: [dataTrace.x],
    y: [dataTrace.y],
    name: `model after ${N} epochs`,
  };
  Plotly.restyle('dataSpaceWith4Lines', update, traceIndex);
}
// Start the fit from k = 0, b = 0 so the illustration looks tidy.
// Drop the setWeights call below to see how the model behaves when it
// starts from its random initialization instead.
let k = 0;
let b = 0;
const initialKernel = tf.tensor2d([k], [1, 1]);
const initialBias = tf.tensor1d([b]);
model.setWeights([initialKernel, initialBias]);
////
// Train the model within an async block
////
(async () => {
  // Keyed by the zero-based epoch index reported to onEpochEnd. Each entry
  // names the line trace to redraw, the 1-based epoch count to show, and
  // that trace's index on the plot.
  const snapshots = new Map([
    [9, {trace: dataTrace10Epochs, label: 10, traceIndex: 2}],
    [19, {trace: dataTrace20Epochs, label: 20, traceIndex: 3}],
    [99, {trace: dataTrace100Epochs, label: 100, traceIndex: 4}],
    [199, {trace: dataTrace200Epochs, label: 200, traceIndex: 5}],
  ]);

  await model.fit(trainTensors.sizeMB, trainTensors.timeSec, {
    epochs: 200,
    // fit() invokes these callbacks at the matching points of training.
    callbacks: {
      onEpochEnd: async (epoch) => {
        // Read the current slope and intercept back out of the model.
        const [kernel, bias] = model.getWeights();
        k = kernel.dataSync()[0];
        b = bias.dataSync()[0];

        const snapshot = snapshots.get(epoch);
        if (snapshot !== undefined) {
          const {trace, label, traceIndex} = snapshot;
          updateScatterWithLines(trace, k, b, label, traceIndex);
          console.log(`wrote model ${label}`);
        }

        // Wait for the next animation frame before the following epoch.
        await tf.nextFrame();
      },
    },
  });
})();
This Pen doesn't use any external CSS resources.
This Pen doesn't use any external JavaScript resources.