前回
んで、クラスタリングを理解しようとしました。
wikipediaを参照
k-meansってのでやるのはわかった。じゃあどんな計算しているのかな?
なるほど!わからん!(^ω^
本を読んで見る
この本を買ってみて一通り読んでみる。
なるほど!わからん(^ω^
難しい。。。
前回用意したコードを読む
ならばプログラムを読めば良いのではないでしょうか。前回のコードを読んでみる。
km = MiniBatchKMeans( n_clusters = self.num_clusters, init = 'k-means++', batch_size = 1000, n_init = 10, max_no_improvement = 10, verbose = True )
抽象化されててよめまてーん(^ω^
試すときパッと使えるので、すごく良い抽象なのですが、理解するときにはキツイですね。これじゃ何やっているかわかりません。
ということで他のコードを探してみる。
見つけました。D3.js
というライブラリを使ってヴィジュアル化までされている!これは良い。
こんな感じで用意しました。
index.html
<!DOCTYPE html> <html lang="ja"> <head> <meta charset="UTF-8"> <title>Document</title> <script src="https://code.jquery.com/jquery-1.11.3.min.js"></script> </head> <style> .fl { float: left; } .fr { float: right; } #kmeans { width: 840px; height: 500px; } #viewer { width: 500px; height: 500px; } .fieldset { display: inline; margin: .8em 0 1em 0; border: 1px solid #999; padding: .5em; width: 100px; } </style> <body> <div id="kmeans"> <div class="fl"> <svg id="viewer"></svg> </div> <div class="fl"> <div> <button id="step">ステップ</button> <button id="restart" disabled>最初から</button> </div> <fieldset class="fieldset"> <div> <label for="N">N (ノード数):</label> <input type="number" id="N" min="2" max="1000" value="100"> </div> <div> <label for="K">K (クラスター数):</label> <input type="number" id="K" min="2" max="50" value="5"> </div> <div> <button id="reset">新規作成</button> </div> </fieldset> </div> <div clear="all"></div> </div> <script src="http://d3js.org/d3.v3.min.js" charset="utf-8"></script> <script src="./k-means.js"></script> </body> </html> </code>
k-means.js
var flag = false; var WIDTH = d3.select("#viewer")[0][0].offsetWidth - 20; var HEIGHT = Math.max(300, WIDTH * .7); var svg = d3.select("svg#viewer") .attr('width', WIDTH) .attr('height', HEIGHT) .style('padding', '10px') .style('background', '#223344') .style('cursor', 'pointer') .style('-webkit-user-select', 'none') .style('-khtml-user-select', 'none') .style('-moz-user-select', 'none') .style('-ms-user-select', 'none') .style('user-select', 'none') .on('click', function() { d3.event.preventDefault(); step(); } ); d3.selectAll("#kmeans button").style('padding', '.5em .8em'); // add padding d3.selectAll("#kmeans label") // add inline-block, width .style('display', 'inline-block') .style('width', '15em'); var lineg = svg.append('g'); var dotg = svg.append('g'); var centerg = svg.append('g'); /** * step */ var stepFlag = true; var latist; d3.select("#step").on('click', function() { if (stepFlag) { stepFlag = false; // 2重クリック防止 step(); latist = $.when(draw()); } latist.done(function() { stepFlag = true; }) }); /** * 再生成 */ d3.select("#restart").on('click', function() { restart(); draw(); }); /** * 初期生成 */ d3.select("#reset").on('click', function() { init(); draw(); }); var groups = [], dots = []; /** * stepする * 流れとしては * 1.クラスタとノードの紐付けを更新 * 2.クラスタを移動 * 1.2を計算できなくなるまで(?)繰り返す */ function step() { d3.select("#restart").attr("disabled", null); // restart buttonを押せるようにする if (flag) { moveCenter(); draw(); } else { updateGroups(); draw(); } flag = !flag; } /** * 生成する */ function init() { d3.select("#restart").attr("disabled", "disabled"); var N = parseInt(d3.select('#N')[0][0].value, 10); var K = parseInt(d3.select('#K')[0][0].value, 10); groups = []; // クラスタを生成 for (var i = 0; i < K; i++) { var g = { dots : [], color : 'hsl(' + (i * 360 / K) + ',100%,50%)', center: { x: Math.random() * WIDTH, y: Math.random() * HEIGHT }, init : { center: {} } }; g.init.center = { x: g.center.x, y: g.center.y }; groups.push(g); } dots = []; flag = false; // ノードを生成 for (i = 0; i < N; i++) { var dot = { x: Math.random() * WIDTH, y: Math.random() * HEIGHT, group: undefined }; dot.init = { x: dot.x, y: dot.y, group: dot.group }; dots.push(dot); } } /** * 再生成 */ function restart() { flag = false; d3.select("#restart").attr("disabled", "disabled"); groups.forEach(function(g) { g.dots = []; g.center.x = g.init.center.x; g.center.y = g.init.center.y; }); for (var i = 0; i < dots.length; i++) { var dot = dots[i]; dots[i] = { x: dot.init.x, y: dot.init.y, group: undefined, init: dot.init }; } } /** * 描画 */ function draw() { var circles = dotg.selectAll('circle').data(dots); circles.enter().append('circle'); circles.exit().remove(); circles .transition() .duration(500) .attr('cx', function(d) { return d.x; }) .attr('cy', function(d) { return d.y; }) .attr('fill', function(d) { return d.group ? d.group.color : '#ffffff'; }) .attr('r', 5); if (dots[0].group) { var l = lineg.selectAll('line').data(dots); var updateLine = function(lines) { lines .attr('x1', function(d) { return d.x; }) .attr('y1', function(d) { return d.y; }) .attr('x2', function(d) { return d.group.center.x; }) .attr('y2', function(d) { return d.group.center.y; }) .attr('stroke', function(d) { return d.group.color; }); }; updateLine(l.enter().append('line')); updateLine(l.transition().duration(500)); l.exit().remove(); } else { lineg.selectAll('line').remove(); } var c = centerg.selectAll('path').data(groups); var updateCenters = function(centers) { centers .attr('transform', function(d) { return "translate(" + d.center.x + "," + d.center.y + ") rotate(45)"; }) .attr('fill', function(d,i) { return d.color; }) .attr('stroke', '#aabbcc'); }; c.exit().remove(); // クラスターを表示 updateCenters(c.enter() .append('path') .attr('d', d3.svg.symbol().type('cross')) .attr('stroke', '#aabbcc')); // ノードを表示 updateCenters(c.transition().duration(500));} /** * クラスタの位置を更新する */ function moveCenter() { groups.forEach(function(group, i) { if (group.dots.length == 0) return; // get center of gravity var x = 0, y = 0; group.dots.forEach(function(dot) { x += dot.x; y += dot.y; }); group.center = { x: x / group.dots.length, y: y / group.dots.length }; }); } /** * クラスタとノードの紐付けを行う */ function updateGroups() { groups.forEach(function(g) { g.dots = []; }); dots.forEach(function(dot) { // find the nearest group var min = Infinity; var group; groups.forEach(function(g) { var d = Math.pow(g.center.x - dot.x, 2) + Math.pow(g.center.y - dot.y, 2); if (d < min) { min = d; group = g; } }); // update group group.dots.push(dot); dot.group = group; }); } init(); draw();
コメントは正しいとは思いません。修正があればあとから直します。
とりあえずこれを実行すると動くようになりました。
また画面に描画されるのですごく良いライブラリですね。
そろそろ眠たいのでまとめ
雑なまとめ
1.クラスタの割り振り
2.データを割り振り
3.紐付け
((クラスタの位置.x - ノード位置.x)2 ) + ((クラスタの位置.y - ノード位置.y)2)
これはwikipediaの2の手順。データの元に中心を探す
k-means.js : 235行目
if (d < min) { min = d; group = g; }
4.位置を更新
- ノードの位置情報をすべてインクリメント x、y
- クラスタ.center.x = ノードの位置情報すべて足したx / 紐付いたノードの数 。クラスタ.center.y = ノードの位置情報すべて足したy / 紐付いたノードの数
- 1、2をクラスタの数分繰り返す
個人的なイメージ
ちょっとだけイメージしやすくなったタイミング
ノードを極端に少なくしたらわかりやすいかも。
試しにクラスタを5、ノードを5にしてみた。
紐付けのフェーズ
- ノードとクラスタをランダムに配置
- ノードとクラスタの位置を計算
- すべてのクラスタと比較して、一番近くにいるクラスタを探す -> ノードとクラスタを紐付け
- ノードの数だけ、2と3をループする -> ノードとクラスタが紐付け
真ん中を計算
- クラスタに紐付いてるノードのx情報、y情報を全部足す。allX、allYの出来上がり
その値を紐付いているノードの数で割る -> 一番近いところがわかる(?)
収束まで1〜6を繰り返す
少なくしてわかったこと
- クラスタとノードは紐付かない場合がある
- 計算できなくなる時は、真ん中に移動したら終わり(収束)
参考資料
ソースパクリリスペクトしました。ありがとうございます。