<!DOCTYPE html>
<meta charset="utf-8">
<style>
/* Optimizer trajectory paths (lower-case class names): stroke colour per method. */
.sgd {
    stroke: black;
}
.momentum {
    stroke: blue;
}
.rmsprop {
    stroke: red;
}
.adam {
    stroke: green;
}
/* Toggle buttons (class name = button label): fill colour per method. */
.SGD {
    fill: black;
}
.Momentum {
    fill: blue;
}
.RMSProp {
    fill: red;
}
.Adam {
    fill: green;
}
/* The circles are clickable toggles: show a pointer and lighten on hover. */
circle {
    cursor: pointer;
}
circle:hover {
    fill-opacity: .3;
}
</style>
<body>
<script src="https://d3js.org/d3.v4.min.js"></script>
<script src="https://d3js.org/d3-contour.v1.min.js"></script>
<script src="https://d3js.org/d3-scale-chromatic.v1.min.js"></script>
<script>
// SVG size in pixels; f is sampled on a grid with one point every 5 px.
var width = 960,
    height = 500,
    nx = Math.floor(width / 5), // grid sizes
    ny = Math.floor(height / 5),
    h = 1e-7, // step used when approximating gradients
    drawing_time = 30; // animation time per optimizer step, in ms
var svg = d3.select("body")
    .append("svg")
    .attr("width", width)
    .attr("height", height);
// Region of the (x, y) plane that is shown, and the range of f that gets contour bands.
var domain_x = [-2, 2],
    domain_y = [-2, 2],
    domain_f = [-2, 8],
    contour_step = 0.5; // spacing between contour thresholds
// Pixel -> function coordinates; use .invert() for function -> pixel.
var scale_x = d3.scaleLinear()
    .domain([0, width])
    .range(domain_x);
var scale_y = d3.scaleLinear()
    .domain([0, height])
    .range(domain_y);
var thresholds = d3.range(domain_f[0], domain_f[1], contour_step);
// Maps a contour threshold onto the YlGnBu colour ramp across the threshold extent.
var color_scale = d3.scaleSequential(d3.interpolateYlGnBu)
    .domain(d3.extent(thresholds));
// Layers, bottom to top: contour map (click to start), optimizer paths, toggle menu.
var function_g = svg.append("g").on("mousedown", mousedown),
    gradient_path_g = svg.append("g"),
    menu_g = svg.append("g");
/*
* Set up the function and gradients
*/
/*
 * Objective surface: the bowl x^2 + y^2 plus two Gaussian wells of
 * amplitude 2 centred at (1, 0) and amplitude 3 centred at (-1, 0).
 */
function f(x, y) {
    var ySquared = y * y;
    var dxRight = x - 1,
        dxLeft = x + 1;
    var rightWell = -2 * Math.exp(-(dxRight * dxRight + ySquared) / .2);
    var leftWell = -3 * Math.exp(-(dxLeft * dxLeft + ySquared) / .2);
    return rightWell + leftWell + x * x + ySquared;
}
/*
 * Numerical gradient of f at (x, y) using forward differences with step h.
 * Returns [df/dx, df/dy].
 */
function grad_f(x, y) {
    // Evaluate f at the base point once; both partial derivatives reuse it.
    var f_xy = f(x, y),
        grad_x = (f(x + h, y) - f_xy) / h,
        grad_y = (f(x, y + h) - f_xy) / h;
    return [grad_x, grad_y];
}
/*
 * Samples f on an nx-by-ny grid covering the SVG (grid cell (i, j) sits at
 * pixel (i * width / nx, j * height / ny)) and returns the values as a flat
 * array indexed i + j * nx, the row-major layout d3.contours expects.
 */
function get_f_values(nx, ny) {
    var grid = new Array(nx * ny);
    // Row-major traversal: writes are sequential and y is computed once per row.
    for (var j = 0; j < ny; j++) {
        var y = scale_y(j / ny * height);
        for (var i = 0; i < nx; i++) {
            var x = scale_x(i / nx * width);
            grid[i + j * nx] = f(x, y);
        }
    }
    return grid;
}
/*
* Set up the contour plot
*/
var contours = d3.contours()
        .size([nx, ny])
        .thresholds(thresholds),
    f_values = get_f_values(nx, ny);
(function draw_contours() {
    // Contour polygons are in grid units; scaling by width / nx converts them to
    // pixels (assumes width / nx == height / ny so one factor serves both axes).
    var to_pixels = d3.geoPath(d3.geoIdentity().scale(width / nx));
    var bands = contours(f_values);
    function_g.selectAll("path")
        .data(bands)
        .enter()
        .append("path")
        .attr("d", to_pixels)
        .attr("fill", function(band) { return color_scale(band.value); })
        .attr("stroke", "none");
})();
/*
* Set up buttons
*/
// Which optimizers are drawn on the next click; flipped by the menu buttons.
var draw_bool = {"SGD" : true, "Momentum" : true, "RMSProp" : true, "Adam" : true};
// Button labels. Each label is also the circle's CSS class (fill colour) and its draw_bool key.
var buttons = ["SGD", "Momentum", "RMSProp", "Adam"];
// Translucent strip along the bottom edge behind the buttons.
menu_g.append("rect")
    .attr("x", 0)
    .attr("y", height - 40)
    .attr("width", width)
    .attr("height", 40)
    .attr("fill", "white")
    .attr("opacity", 0.2);
// One toggle circle per optimizer, spaced a quarter of the width apart.
// fill-opacity 0.5 = enabled, 0 = disabled (see button_press).
menu_g.selectAll("circle")
    .data(buttons)
    .enter()
    .append("circle")
    .attr("cx", function(d, i) { return width / 4 * (i + 0.25); })
    .attr("cy", height - 20)
    .attr("r", 10)
    .attr("stroke-width", 0.5)
    .attr("stroke", "black")
    .attr("class", function(d) { return d; })
    .attr("fill-opacity", 0.5)
    .attr("stroke-opacity", 1)
    .on("mousedown", button_press);
// Label to the right of each circle.
menu_g.selectAll("text")
    .data(buttons)
    .enter()
    .append("text")
    .attr("x", function(d, i) { return width / 4 * (i + 0.25) + 18; })
    .attr("y", height - 14)
    .text(function(d) { return d; })
    .attr("text-anchor", "start")
    .attr("font-family", "Helvetica Neue")
    .attr("font-size", 15)
    .attr("font-weight", 200)
    .attr("fill", "white")
    .attr("fill-opacity", 0.8);
/*
 * Mousedown handler for a menu circle: flips whether the optimizer named by the
 * circle's class is drawn, and shows the new state through the circle's fill
 * opacity (0.5 when enabled, 0 when disabled).
 */
function button_press() {
    var circle = d3.select(this),
        type = circle.attr("class"),
        enabled = !draw_bool[type];
    draw_bool[type] = enabled;
    circle.attr("fill-opacity", enabled ? 0.5 : 0);
}
/*
* Set up optimization/gradient descent functions.
* SGD, Momentum, RMSProp, Adam.
*/
/*
 * Plain gradient descent starting at (x0, y0), given in function coordinates:
 * position <- position - learning_rate * grad.
 * Returns num_steps + 1 points, converted to pixel coordinates for drawing.
 */
function get_sgd_path(x0, y0, learning_rate, num_steps) {
    var sgd_history = [{"x": scale_x.invert(x0), "y": scale_y.invert(y0)}];
    for (var step = 0; step < num_steps; step++) {
        var gradient = grad_f(x0, y0);
        x0 -= learning_rate * gradient[0];
        y0 -= learning_rate * gradient[1];
        sgd_history.push({"x": scale_x.invert(x0), "y": scale_y.invert(y0)});
    }
    return sgd_history;
}
/*
 * Gradient descent with classical momentum starting at (x0, y0), in function
 * coordinates: v <- momentum * v - learning_rate * grad; position <- position + v.
 * Returns num_steps + 1 points, converted to pixel coordinates for drawing.
 */
function get_momentum_path(x0, y0, learning_rate, num_steps, momentum) {
    var v_x = 0,
        v_y = 0;
    var momentum_history = [{"x": scale_x.invert(x0), "y": scale_y.invert(y0)}];
    for (var step = 0; step < num_steps; step++) {
        var gradient = grad_f(x0, y0);
        v_x = momentum * v_x - learning_rate * gradient[0];
        v_y = momentum * v_y - learning_rate * gradient[1];
        x0 += v_x;
        y0 += v_y;
        momentum_history.push({"x": scale_x.invert(x0), "y": scale_y.invert(y0)});
    }
    return momentum_history;
}
/*
 * RMSProp starting at (x0, y0), in function coordinates. Each coordinate's step
 * is learning_rate * grad / (sqrt(cache) + eps), where cache is an exponential
 * moving average (decay decay_rate) of that coordinate's squared gradient.
 * Returns num_steps + 1 points, converted to pixel coordinates for drawing.
 */
function get_rmsprop_path(x0, y0, learning_rate, num_steps, decay_rate, eps) {
    var cache_x = 0,
        cache_y = 0;
    var rmsprop_history = [{"x": scale_x.invert(x0), "y": scale_y.invert(y0)}];
    for (var step = 0; step < num_steps; step++) {
        var gradient = grad_f(x0, y0);
        cache_x = decay_rate * cache_x + (1 - decay_rate) * gradient[0] * gradient[0];
        cache_y = decay_rate * cache_y + (1 - decay_rate) * gradient[1] * gradient[1];
        x0 -= learning_rate * gradient[0] / (Math.sqrt(cache_x) + eps);
        y0 -= learning_rate * gradient[1] / (Math.sqrt(cache_y) + eps);
        rmsprop_history.push({"x": scale_x.invert(x0), "y": scale_y.invert(y0)});
    }
    return rmsprop_history;
}
/*
 * Adam starting at (x0, y0), in function coordinates. Each coordinate's step is
 * learning_rate * m / (sqrt(v) + eps), where m and v are exponential moving
 * averages (decays beta_1 and beta_2) of the gradient and the squared gradient.
 *
 * By default m and v are used as-is, WITHOUT the bias correction of Kingma & Ba
 * (m / (1 - beta_1^t), v / (1 - beta_2^t)). Because both start at 0, early steps
 * therefore differ from textbook Adam. Pass bias_correction = true to apply it.
 *
 * Returns num_steps + 1 points, converted to pixel coordinates for drawing.
 */
function get_adam_path(x0, y0, learning_rate, num_steps, beta_1, beta_2, eps, bias_correction) {
    var m_x = 0,
        m_y = 0,
        v_x = 0,
        v_y = 0;
    var adam_history = [{"x": scale_x.invert(x0), "y": scale_y.invert(y0)}];
    for (var t = 1; t <= num_steps; t++) {
        var gradient = grad_f(x0, y0);
        m_x = beta_1 * m_x + (1 - beta_1) * gradient[0];
        m_y = beta_1 * m_y + (1 - beta_1) * gradient[1];
        v_x = beta_2 * v_x + (1 - beta_2) * gradient[0] * gradient[0];
        v_y = beta_2 * v_y + (1 - beta_2) * gradient[1] * gradient[1];
        // Divisors are 1 (no correction) unless bias_correction is requested.
        var m_corr = bias_correction ? 1 - Math.pow(beta_1, t) : 1,
            v_corr = bias_correction ? 1 - Math.pow(beta_2, t) : 1;
        x0 -= learning_rate * (m_x / m_corr) / (Math.sqrt(v_x / v_corr) + eps);
        y0 -= learning_rate * (m_y / m_corr) / (Math.sqrt(v_y / v_corr) + eps);
        adam_history.push({"x": scale_x.invert(x0), "y": scale_y.invert(y0)});
    }
    return adam_history;
}
/*
* Functions necessary for path visualizations
*/
// Path generator: turns an array of {x, y} pixel points into an SVG path string.
var line_function = d3.line();
line_function.x(function(point) { return point.x; });
line_function.y(function(point) { return point.y; });
/*
 * Animates one optimizer trajectory inside gradient_path_g.
 *
 * One transient path is appended per point. After i * drawing_time ms, the i-th
 * path transitions to the polyline through the first i + 1 points and is then
 * removed, so the line appears to grow. Meanwhile a permanent path holding the
 * whole trajectory fades in over the same total time.
 *
 * path_data: array of {x, y} pixel points.
 * type: CSS class applied to the paths (selects the stroke colour).
 */
function draw_path(path_data, type) {
    // Styling shared by the transient segments and the permanent path.
    function style_path(selection) {
        selection
            .attr("class", type)
            .attr("stroke-width", 3)
            .attr("fill", "none");
    }

    // Empty selection: every point enters and gets its own transient path.
    gradient_path_g.selectAll(null)
        .data(path_data)
        .enter()
        .append("path")
        .attr("d", line_function(path_data.slice(0, 1)))
        .call(style_path)
        .attr("stroke-opacity", 0.5)
        .transition()
        .duration(drawing_time)
        .delay(function(d, i) { return drawing_time * i; })
        .attr("d", function(d, i) { return line_function(path_data.slice(0, i + 1)); })
        .remove();

    // Permanent full trajectory, starting invisible and fading in.
    gradient_path_g.append("path")
        .attr("d", line_function(path_data))
        .call(style_path)
        .attr("stroke-opacity", 0)
        .transition()
        .duration(path_data.length * drawing_time)
        .attr("stroke-opacity", 0.5);
}
/*
* Start minimization from click on contour map
*/
/*
 * Mousedown handler for the contour layer: converts the click position from
 * pixels to function coordinates and starts the optimizers from there.
 */
function mousedown() {
    var pixel = d3.mouse(this),
        start_x = scale_x(pixel[0]),
        start_y = scale_y(pixel[1]);
    minimize(start_x, start_y);
}
/*
 * Clears all drawn trajectories, then, for each optimizer enabled in draw_bool
 * (in the order SGD, Momentum, RMSProp, Adam), computes its path from (x0, y0)
 * in function coordinates with fixed hyper-parameters and animates it.
 */
function minimize(x0, y0) {
    gradient_path_g.selectAll("path").remove();
    var runs = [
        {key: "SGD", css: "sgd", compute: function() { return get_sgd_path(x0, y0, 2e-2, 500); }},
        {key: "Momentum", css: "momentum", compute: function() { return get_momentum_path(x0, y0, 1e-2, 200, 0.8); }},
        {key: "RMSProp", css: "rmsprop", compute: function() { return get_rmsprop_path(x0, y0, 1e-2, 300, 0.99, 1e-6); }},
        {key: "Adam", css: "adam", compute: function() { return get_adam_path(x0, y0, 1e-2, 100, 0.7, 0.999, 1e-6); }}
    ];
    runs.forEach(function(run) {
        if (draw_bool[run.key]) {
            draw_path(run.compute(), run.css);
        }
    });
}
</script>
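<!--
CodePen export note (kept from the original export; not page content):
This Pen doesn't use any external CSS resources.
This Pen doesn't use any external JavaScript resources.
(D3 is instead loaded directly by the <script> tags above.)
-->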