'use strict'; // 学習パターンを識別関数に適用して誤識別した場合、パーセプトロンの学習規則に基づき重みを調整する // ws : 拡張重みベクトル // xs : 拡張特徴ベクトル // c : クラス (0 or 1) // rho: 学習係数ρ function learn(ws, xs, c, rho) { var g = dot(ws, xs); // 識別関数gを計算 if (c == 0 && g <= 0) { // クラスω1を誤識別 madd(ws, rho, xs); // w' = w + ρx return true; } else if (c != 0 && g >= 0) { // クラスω2を誤識別 madd(ws, -rho, xs); // w' = w - ρx return true; } return false; } // 内積 function dot(a, b) { var l = 0; for (var i = 0; i < a.length; ++i) l += a[i] * b[i]; return l; } // 係数付き加算 function madd(a, p, b) { for (var i = 0; i < a.length; ++i) a[i] += p * b[i]; } // 学習パターン var dataset = [ [0, 1.2], [0, 0.2], [0, -0.2], [1, -0.5], [1, -1.0], [1, -1.5] ]; // 拡張重みベクトル var ws = [-7, 2]; //[11, 5]; // D3.js関連 var width = 600, height = 600, padding = 30; // 重みベクトルの軌跡 var plots = []; // 原点からの距離 function dist(x, y) { return Math.sqrt(x * x + y * y); } // 描画 function draw() { d3.select('svg').remove(); var svg = d3.select('body').append('svg') .attr('width', width) .attr('height', height); var xScale = d3.scale.linear() .domain([-12, 12]) .range([padding, width - padding]); var yScale = d3.scale.linear() .domain([-12, 12]) .range([height - padding, padding]); // パターンの座標を計算 function patternX(d1, r) { return xScale(r / dist(d1, 1)); } function patternY(d1, r) { return yScale(-r / dist(d1, 1) * d1); } var lines, nodes; // パターンをドラッグ時のイベント var dragPattern = (function() { var dragX, dragY; return d3.behavior.drag() .on('dragstart', function(d, i) { d3.event.sourceEvent.stopPropagation(); dragX = xScale(13 / dist(d[1], 1)); dragY = yScale(-13 / dist(d[1], 1) * d[1]); }) .on('drag', function(d, i) { dragX += d3.event.dx; dragY += d3.event.dy; var x = xScale.invert(dragX); var y = yScale.invert(dragY); d[1] = -y / x; d3.select(nodes[0][i]) .attr('transform', function (d) { return 'translate(' + patternX(d[1], 13) + ',' + patternY(d[1], 13) + ')'; }); d3.select(lines[0][i]) .attr('x1', function (d) { return patternX(d[1], -12); }) .attr('y1', function (d) { return patternY(d[1], -12); }) .attr('x2', function (d) { return patternX(d[1], 12); }) .attr('y2', function (d) { return patternY(d[1], 12); }) }); })(); // 学習パターンを示す線 lines = svg.selectAll('line') .data(dataset) .enter() .append('line') .attr('x1', function (d) { return patternX(d[1], -12); }) .attr('y1', function (d) { return patternY(d[1], -12); }) .attr('x2', function (d) { return patternX(d[1], 12); }) .attr('y2', function (d) { return patternY(d[1], 12); }) .attr('stroke-width', 2) .attr('stroke', function (d) { return d[0] == 0 ? 'red' : 'blue'; }); // 学習パターンの番号を表示 nodes = svg.append('g') .attr('class', 'nodes') .selectAll('circle') .data(dataset) .enter() .append('g') .attr('transform', function (d) { return 'translate(' + patternX(d[1], 13) + ',' + patternY(d[1], 13) + ')'; }) .call(dragPattern); nodes.append('circle') .attr('class', 'draggable') .attr('r', 10) .attr('fill', function (d) { return d[0] == 0 ? 'red' : 'blue'; }); nodes.append('text') .text(function (_, i) { return i + 1; }) .attr('class', 'draggable') .attr('font-size', '11px') .attr('dy', '.35em') .attr('text-anchor', 'middle') .attr('fill', 'white'); // 拡張重みベクトルの履歴表示 svg.selectAll('circle.plot') .data(plots) .enter() .append('circle') .attr('cx', function (d) { return xScale(d[0]); }) .attr('cy', function (d) { return yScale(d[1]); }) .attr('r', 3) .attr('fill', function (_, i) { return (i < plots.length - 1) ? 'black' : '#0f0'; }); // 軸描画 var xAxis = d3.svg.axis() .scale(xScale) .orient('bottom') .ticks(5); svg.append('g') .attr('class', 'axis') .attr('transform', 'translate(0,' + yScale(0) + ')') .call(xAxis); svg.append('text') .text('ω1') .attr('font-size', '11px') .attr('x', function () { return xScale(12); }) .attr('y', function () { return yScale(0); }) var yAxis = d3.svg.axis() .scale(yScale) .orient('left') .ticks(5); svg.append('g') .attr('class', 'axis') .attr('transform', 'translate(' + xScale(0) + ',0)') .call(yAxis); svg.append('text') .text('ω0') .attr('font-size', '11px') .attr('x', function () { return xScale(0); }) .attr('y', function () { return yScale(12); }) } // 0~n-1までを順に並べた配列を返す function seq(n) { var a = new Array(n); for (var i = 0; i < n; ++i) a[i] = i; return a; } // 配列をシャッフルしてランダムな順番に並べる function shuffle(a) { for (var i = a.length; --i > 0;) { var j = Math.floor(Math.random() * i); var t = a[i]; a[i] = a[j]; a[j] = t; } return a; } // ターミナルに文字列出力 function terminal(s) { d3.select('#terminal').text(s); } // 小数点以下d桁で四捨五入 function round(x, d) { var k = Math.pow(10, 2); return Math.round(x * k) / k; } // 拡張重みベクトルをフォーマット function showWeight() { return round(ws[0], 2) + ', ' + round(ws[1], 2); } // ランダムな順番にパターンを適用して、学習が行われたらその番号を返す function step(rho) { var order = shuffle(seq(dataset.length)); for (var i = 0; i < dataset.length; ++i) { var d = dataset[order[i]]; var c = d[0]; var xs = [1].concat(d.slice(1)); // 拡張特徴ベクトル、0=1, 1~=xi if (learn(ws, xs, c, rho)) { return order[i]; } } // 学習が行われなかった return -1; } // メイン function main() { d3.select('#step').on('click', function () { var rho = parseFloat(d3.select('#rho').node().value); if (isNaN(rho)) { terminal('Cannot parse rho as float'); return; } var i = step(rho); if (i >= 0) { // 学習が行われた plots.push([ws[1], ws[0]]); draw(); terminal((i + 1) + ': ' + showWeight()); } else { terminal('Done: ' + showWeight() + ': ' + round(-ws[0] / ws[1], 2)); d3.select('#step').node().disabled = true; } }); d3.select('#reset').on('click', function () { plots.length = 0; do { ws[0] = Math.random() * 10 - 5; ws[1] = Math.random() * 10 - 5; } while (ws[0] == 0 && ws[1] == 0); plots.push([ws[1], ws[0]]); draw(); terminal('Reset: ' + showWeight()); d3.select('#step').node().disabled = false; }); plots.push([ws[1], ws[0]]); draw(); } main();
<div>ρ: <input id="rho" type="text" value="1.2" size="5"> <button id="step">step</button> <button id="reset">reset</button> </div> <div style="background-color: #eee; font-family: monospace"> <span id="terminal"> Ready </span> </div>
body { background-color: white; } .axis path, .axis line { fill: none; stroke: black; shape-rendering: crispEdges; } .axis text { font-family: sans-serif; font-size: 11px; } .draggable { cursor: pointer; }