ITの隊長のブログ

ITの隊長のブログです。Rubyを使って仕事しています。最近も色々やっているお(^ω^ = ^ω^)

機械学習・クラスタリングを理解するまで2日目

スポンサードリンク

前回

aipacommander.hatenablog.jp

んで、クラスタリングを理解しようとしました。

wikipediaを参照

k-meansってのでやるのはわかった。じゃあどんな計算しているのかな?

k平均法 - Wikipedia

なるほど!わからん!(^ω^

本を読んで見る

データサイエンティスト養成読本 機械学習入門編 (Software Design plus)

この本を買ってみて一通り読んでみる。

なるほど!わからん(^ω^

難しい。。。

前回用意したコードを読む

ならばプログラムを読めば良いのではないでしょうか。前回のコードを読んでみる。

      km = MiniBatchKMeans(
        n_clusters         = self.num_clusters,
        init               = 'k-means++',
        batch_size         = 1000,
        n_init             = 10,
        max_no_improvement = 10,
        verbose            = True
        )

抽象化されててよめまてーん(^ω^

試すときパッと使えるので、すごく良い抽象なのですが、理解するときにはキツイですね。これじゃ何やっているかわかりません。

ということで他のコードを探してみる。

見つけました。D3.jsというライブラリを使ってヴィジュアル化までされている!これは良い。

こんな感じで用意しました。

  • index.html
<!DOCTYPE html>
<html lang="ja">
<head>
    <meta charset="UTF-8">
    <title>Document</title>
    <script src="https://code.jquery.com/jquery-1.11.3.min.js"></script>
</head>
<style>
.fl {
  float: left;
}
.fr {
  float: right;
}
#kmeans {
  width: 840px;
  height: 500px;
}
#viewer {
  width: 500px;
  height: 500px;
}
.fieldset {
  display: inline;
  margin: .8em 0 1em 0;
  border: 1px solid #999;
  padding: .5em;
  width: 100px;
}
</style>
<body>
    <div id="kmeans">
        <div class="fl">
            <svg id="viewer"></svg>
        </div>
        <div class="fl">
          <div>
              <button id="step">ステップ</button>
              <button id="restart" disabled>最初から</button>
          </div>
          <fieldset class="fieldset">
              <div>
                  <label for="N">N (ノード数):</label>
                  <input type="number" id="N" min="2" max="1000" value="100">
              </div>
              <div>
                  <label for="K">K (クラスター数):</label>
                  <input type="number" id="K" min="2" max="50" value="5">
              </div>
              <div>
                  <button id="reset">新規作成</button>
              </div>
          </fieldset>
        </div>
        <div clear="all"></div>
    </div>
    <script src="http://d3js.org/d3.v3.min.js" charset="utf-8"></script>
    <script src="./k-means.js"></script>
</body>

</html>

</code>
  • k-means.js
var flag   = false;
var WIDTH  = d3.select("#viewer")[0][0].offsetWidth - 20; 
var HEIGHT = Math.max(300, WIDTH * .7);
var svg    = d3.select("svg#viewer")
  .attr('width', WIDTH)
  .attr('height', HEIGHT)
  .style('padding', '10px')
  .style('background', '#223344')
  .style('cursor', 'pointer')
  .style('-webkit-user-select', 'none')
  .style('-khtml-user-select', 'none')
  .style('-moz-user-select', 'none')
  .style('-ms-user-select', 'none')
  .style('user-select', 'none')
  .on('click', function() {
    d3.event.preventDefault();
    step();
  }
);

d3.selectAll("#kmeans button").style('padding', '.5em .8em'); // add padding 
d3.selectAll("#kmeans label")                                 // add inline-block, width
  .style('display', 'inline-block')
  .style('width', '15em');

var lineg   = svg.append('g');
var dotg    = svg.append('g');
var centerg = svg.append('g');
/**
 * step
 */
var stepFlag = true;
var latist;
d3.select("#step").on('click', function() {
  if (stepFlag) {
    stepFlag = false; // 2重クリック防止
    step();
    latist   = $.when(draw());
  }
  latist.done(function() {
    stepFlag = true;
  })
});
/**
 * 再生成
 */
d3.select("#restart").on('click', function() {
  restart();
  draw();
});
/**
 * 初期生成
 */
d3.select("#reset").on('click', function() {
  init();
  draw();
});


var groups = [], dots = [];

/**
 * stepする
 * 流れとしては
 * 1.クラスタとノードの紐付けを更新
 * 2.クラスタを移動
 * 1.2を計算できなくなるまで(?)繰り返す
 */
function step() {
  d3.select("#restart").attr("disabled", null); // restart buttonを押せるようにする
  if (flag) {
    moveCenter();
    draw();
  } else {
    updateGroups();
    draw();
  }
  flag = !flag;
}

/**
 * 生成する
 */
function init() {
  d3.select("#restart").attr("disabled", "disabled");
  var N  = parseInt(d3.select('#N')[0][0].value, 10);
  var K  = parseInt(d3.select('#K')[0][0].value, 10);
  groups = [];
  // クラスタを生成
  for (var i = 0; i < K; i++) {
    var g = {
      dots  : [],
      color : 'hsl(' + (i * 360 / K) + ',100%,50%)',
      center: {
        x: Math.random() * WIDTH,
        y: Math.random() * HEIGHT
      },
      init  : {
        center: {}
      }
    };
    g.init.center = {
      x: g.center.x,
      y: g.center.y
    };
    groups.push(g);
  }

  dots = [];
  flag = false;
  // ノードを生成
  for (i = 0; i < N; i++) {
    var dot = {
      x: Math.random() * WIDTH,
      y: Math.random() * HEIGHT,
      group: undefined
    };
    dot.init = {
      x: dot.x,
      y: dot.y,
      group: dot.group
    };
    dots.push(dot);
  }
}

/**
 * 再生成
 */
function restart() {
  flag = false;
  d3.select("#restart").attr("disabled", "disabled");

  groups.forEach(function(g) {
    g.dots     = [];
    g.center.x = g.init.center.x;
    g.center.y = g.init.center.y;
  });

  for (var i = 0; i < dots.length; i++) {
    var dot = dots[i];
    dots[i] = {
      x: dot.init.x,
      y: dot.init.y,
      group: undefined,
      init: dot.init
    };
  }
}


/**
 * 描画
 */
function draw() {
  var circles = dotg.selectAll('circle').data(dots);
  circles.enter().append('circle');
  circles.exit().remove();
  circles
    .transition()
    .duration(500)
    .attr('cx', function(d) { return d.x; })
    .attr('cy', function(d) { return d.y; })
    .attr('fill', function(d) { return d.group ? d.group.color : '#ffffff'; })
    .attr('r', 5);

  if (dots[0].group) {
    var l = lineg.selectAll('line').data(dots);
    var updateLine = function(lines) {
      lines
        .attr('x1', function(d) { return d.x; })
        .attr('y1', function(d) { return d.y; })
        .attr('x2', function(d) { return d.group.center.x; })
        .attr('y2', function(d) { return d.group.center.y; })
        .attr('stroke', function(d) { return d.group.color; });
    };
    updateLine(l.enter().append('line'));
    updateLine(l.transition().duration(500));
    l.exit().remove();
  } else {
    lineg.selectAll('line').remove();
  }

  var c = centerg.selectAll('path').data(groups);
  var updateCenters = function(centers) {
    centers
      .attr('transform', function(d) {
        return "translate(" + d.center.x + "," + d.center.y + ") rotate(45)";
      })
      .attr('fill', function(d,i) {
        return d.color;
      })
      .attr('stroke', '#aabbcc');
  };
  c.exit().remove();
  // クラスターを表示
  updateCenters(c.enter()
    .append('path')
    .attr('d', d3.svg.symbol().type('cross'))
    .attr('stroke', '#aabbcc'));
  // ノードを表示
  updateCenters(c.transition().duration(500));}

/**
 * クラスタの位置を更新する
 */
function moveCenter() {
  groups.forEach(function(group, i) {
    if (group.dots.length == 0) return;
    // get center of gravity
    var x = 0, y = 0;
    group.dots.forEach(function(dot) {
      x += dot.x;
      y += dot.y;
    });

    group.center = {
      x: x / group.dots.length,
      y: y / group.dots.length
    };
  });
}

/**
 * クラスタとノードの紐付けを行う
 */
function updateGroups() {
  groups.forEach(function(g) { g.dots = []; });
  dots.forEach(function(dot) {
    // find the nearest group
    var min = Infinity;
    var group;
    groups.forEach(function(g) {
      var d = Math.pow(g.center.x - dot.x, 2) + Math.pow(g.center.y - dot.y, 2);
      if (d < min) {
        min   = d;
        group = g;
      }
    });
    // update group
    group.dots.push(dot);
    dot.group = group;
  });
}

init();
draw();

コメントは正しいとは思いません。修正があればあとから直します。

とりあえずこれを実行すると動くようになりました。

また画面に描画されるのですごく良いライブラリですね。

そろそろ眠たいのでまとめ

雑なまとめ

1.クラスタの割り振り

2.データを割り振り

3.紐付け

((クラスタの位置.x - ノード位置.x)2 ) + ((クラスタの位置.y - ノード位置.y)2)

これはwikipediaの2の手順。データの元に中心を探す

k-means.js : 235行目

      if (d < min) {
        min   = d;
        group = g;
      }

4.位置を更新

  1. ノードの位置情報をすべてインクリメント x、y
  2. クラスタ.center.x = ノードの位置情報すべて足したx / 紐付いたノードの数 。クラスタ.center.y = ノードの位置情報すべて足したy / 紐付いたノードの数
  3. 1、2をクラスタの数分繰り返す

個人的なイメージ

ちょっとだけイメージしやすくなったタイミング

ノードを極端に少なくしたらわかりやすいかも。

試しにクラスタを5、ノードを5にしてみた。

紐付けのフェーズ

  1. ノードとクラスタをランダムに配置
  2. ノードとクラスタの位置を計算
  3. すべてのクラスタと比較して、一番近くにいるクラスタを探す -> ノードとクラスタを紐付け
  4. ノードの数だけ、2と3をループする -> ノードとクラスタが紐付け

真ん中を計算

  1. クラスタに紐付いてるノードのx情報、y情報を全部足す。allX、allYの出来上がり
  2. その値を紐付いているノードの数で割る -> 一番近いところがわかる(?)

  3. 収束まで1〜6を繰り返す

少なくしてわかったこと

  • クラスタとノードは紐付かない場合がある
  • 計算できなくなる時は、真ん中に移動したら終わり(収束)

参考資料

k平均法 - Wikipedia

tech.nitoyon.com

ソースパクリリスペクトしました。ありがとうございます。