Teachers open the door but You must enter by yourself.

Open Media Lab.
オープンメディアラボ

【事前学習】前回までの内容を再確認しておきましょう。

クラス分類問題
Classification

3クラス分類の例題

アヤメの特徴量を示す4次元データ(がく片の長さと幅、花びらの長さと幅)を学習させてみましょう。


<!doctype html>
<html lang="ja">
<head>
<meta charset="utf-8" />
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.13.0/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@latest"></script>
<script>
async function trainModel(model, inputs, labels){
	model.compile({
		optimizer: tf.train.sgd(0.01), //tf.train.adam()
		loss: tf.losses.meanSquaredError,
		metrics: ['mse'],
	});

	const batchSize = 32;
	const epochs = 100;//学習の反復回数

	return await model.fit(inputs, labels,{
		batchSize,
		epochs,
		shuffle:true,
		callbacks: tfvis.show.fitCallbacks(
			document.getElementById('training'),
			['mse'], 
			{width:350, height:200, callbacks:['onEpochEnd']}
		)
	});
}

async function run() {
	//学習データ
	const x = [[5.1,3.5,1.4,0.2],[4.9,3.0,1.4,0.2],[4.7,3.2,1.3,0.2],[4.6,3.1,1.5,0.2],[5.0,3.6,1.4,0.2],[5.4,3.9,1.7,0.4],[4.6,3.4,1.4,0.3],[5.0,3.4,1.5,0.2],[4.4,2.9,1.4,0.2],[4.9,3.1,1.5,0.1],[5.4,3.7,1.5,0.2],[4.8,3.4,1.6,0.2],[4.8,3.0,1.4,0.1],[4.3,3.0,1.1,0.1],[5.8,4.0,1.2,0.2],[5.7,4.4,1.5,0.4],[5.4,3.9,1.3,0.4],[5.1,3.5,1.4,0.3],[5.7,3.8,1.7,0.3],[5.1,3.8,1.5,0.3],[5.4,3.4,1.7,0.2],[5.1,3.7,1.5,0.4],[4.6,3.6,1.0,0.2],[5.1,3.3,1.7,0.5],[4.8,3.4,1.9,0.2],[7.0,3.2,4.7,1.4],[6.4,3.2,4.5,1.5],[6.9,3.1,4.9,1.5],[5.5,2.3,4.0,1.3],[6.5,2.8,4.6,1.5],[5.7,2.8,4.5,1.3],[6.3,3.3,4.7,1.6],[4.9,2.4,3.3,1.0],[6.6,2.9,4.6,1.3],[5.2,2.7,3.9,1.4],[5.0,2.0,3.5,1.0],[5.9,3.0,4.2,1.5],[6.0,2.2,4.0,1.0],[6.1,2.9,4.7,1.4],[5.6,2.9,3.6,1.3],[6.7,3.1,4.4,1.4],[5.6,3.0,4.5,1.5],[5.8,2.7,4.1,1.0],[6.2,2.2,4.5,1.5],[5.6,2.5,3.9,1.1],[5.9,3.2,4.8,1.8],[6.1,2.8,4.0,1.3],[6.3,2.5,4.9,1.5],[6.1,2.8,4.7,1.2],[6.4,2.9,4.3,1.3],[6.3,3.3,6.0,2.5],[5.8,2.7,5.1,1.9],[7.1,3.0,5.9,2.1],[6.3,2.9,5.6,1.8],[6.5,3.0,5.8,2.2],[7.6,3.0,6.6,2.1],[4.9,2.5,4.5,1.7],[7.3,2.9,6.3,1.8],[6.7,2.5,5.8,1.8],[7.2,3.6,6.1,2.5],[6.5,3.2,5.1,2.0],[6.4,2.7,5.3,1.9],[6.8,3.0,5.5,2.1],[5.7,2.5,5.0,2.0],[5.8,2.8,5.1,2.4],[6.4,3.2,5.3,2.3],[6.5,3.0,5.5,1.8],[7.7,3.8,6.7,2.2],[7.7,2.6,6.9,2.3],[6.0,2.2,5.0,1.5],[6.9,3.2,5.7,2.3],[5.6,2.8,4.9,2.0],[7.7,2.8,6.7,2.0],[6.3,2.7,4.9,1.8],[6.7,3.3,5.7,2.1]];
	//未知データ(学習に未使用)
	const x2 = [[5.0,3.0,1.6,0.2],[5.0,3.4,1.6,0.4],[5.2,3.5,1.5,0.2],[5.2,3.4,1.4,0.2],[4.7,3.2,1.6,0.2],[4.8,3.1,1.6,0.2],[5.4,3.4,1.5,0.4],[5.2,4.1,1.5,0.1],[5.5,4.2,1.4,0.2],[4.9,3.1,1.5,0.1],[5.0,3.2,1.2,0.2],[5.5,3.5,1.3,0.2],[4.9,3.1,1.5,0.1],[4.4,3.0,1.3,0.2],[5.1,3.4,1.5,0.2],[5.0,3.5,1.3,0.3],[4.5,2.3,1.3,0.3],[4.4,3.2,1.3,0.2],[5.0,3.5,1.6,0.6],[5.1,3.8,1.9,0.4],[4.8,3.0,1.4,0.3],[5.1,3.8,1.6,0.2],[4.6,3.2,1.4,0.2],[5.3,3.7,1.5,0.2],[5.0,3.3,1.4,0.2],[6.6,3.0,4.4,1.4],[6.8,2.8,4.8,1.4],[6.7,3.0,5.0,1.7],[6.0,2.9,4.5,1.5],[5.7,2.6,3.5,1.0],[5.5,2.4,3.8,1.1],[5.5,2.4,3.7,1.0],[5.8,2.7,3.9,1.2],[6.0,2.7,5.1,1.6],[5.4,3.0,4.5,1.5],[6.0,3.4,4.5,1.6],[6.7,3.1,4.7,1.5],[6.3,2.3,4.4,1.3],[5.6,3.0,4.1,1.3],[5.5,2.5,4.0,1.3],[5.5,2.6,4.4,1.2],[6.1,3.0,4.6,1.4],[5.8,2.6,4.0,1.2],[5.0,2.3,3.3,1.0],[5.6,2.7,4.2,1.3],[5.7,3.0,4.2,1.2],[5.7,2.9,4.2,1.3],[6.2,2.9,4.3,1.3],[5.1,2.5,3.0,1.1],[5.7,2.8,4.1,1.3],[7.2,3.2,6.0,1.8],[6.2,2.8,4.8,1.8],[6.1,3.0,4.9,1.8],[6.4,2.8,5.6,2.1],[7.2,3.0,5.8,1.6],[7.4,2.8,6.1,1.9],[7.9,3.8,6.4,2.0],[6.4,2.8,5.6,2.2],[6.3,2.8,5.1,1.5],[6.1,2.6,5.6,1.4],[7.7,3.0,6.1,2.3],[6.3,3.4,5.6,2.4],[6.4,3.1,5.5,1.8],[6.0,3.0,4.8,1.8],[6.9,3.1,5.4,2.1],[6.7,3.1,5.6,2.4],[6.9,3.1,5.1,2.3],[5.8,2.7,5.1,1.9],[6.8,3.2,5.9,2.3],[6.7,3.3,5.7,2.5],[6.7,3.0,5.2,2.3],[6.3,2.5,5.0,1.9],[6.5,3.0,5.2,2.0],[6.2,3.4,5.4,2.3],[5.9,3.0,5.1,1.8]];
	const y = Array(75).fill().map((yi,i)=>yi=i<25?[1,0,0]:i<50?[0,1,0]:[0,0,1]);
	//y[0]~y[24]=[1,0,0], y[25]~y[49]=[0,1,0], y[50]~y[74]=[0,0,1]
	const xs = tf.tensor2d(x,[x.length,4]);
	const ys = tf.tensor2d(y,[y.length,3]);
	
	const model=tf.sequential();
	const layer=tf.layers.dense({inputShape:[4], units:3});
	model.add(layer);
	
	await trainModel(model, xs, ys);

	document.getElementById('output').innerHTML
		= '<pre>'+model.predict(xs).toString()+'</pre>';
	document.getElementById('weight').innerHTML
		= '<pre>'+layer.getWeights().toString()
			.replace(',Tensor',',<br>Tensor')+'</pre>';
	document.getElementById('loss').innerHTML
		= '<pre>'+model.evaluate(xs,ys).toString()
			.replace(',Tensor',',<br>Tensor')+'</pre>';

	const setosa=[];
	const versicolor=[];
	const virginica=[];
	y.map((yi, i) =>{
		const p={x:x[i][2], y:x[i][3]};
		if(yi[0]==1) setosa.push(p);
		else if(yi[1]==1) versicolor.push(p);
		else virginica.push(p);
	});

	tfvis.render.scatterplot(
		document.getElementById('chart'),
		{values:[versicolor,virginica,setosa], series:['_versicolor','_virginica','setosa']},
		{xLabel:'petallength', yLabel:'petalwidth', width:350, height:400, seriesColors:['green','blue','red']}
	);

	const setosa2=[];
	const versicolor2=[];
	const virginica2=[];
	const correct=[];
	model.predict(xs).arraySync().map((yi, i) =>{
		const p={x:x[i][2], y:x[i][3]};
		if(y[i][0]==1){
			if(yi[0]>yi[1]&&yi[0]>yi[2]) correct.push(p);
			else setosa2.push(p);
		}else if(y[i][1]==1){
			if(yi[1]>yi[0]&&yi[1]>yi[2]) correct.push(p);
			else versicolor2.push(p);
		}else{
			if(yi[2]>yi[0]&&yi[2]>yi[1]) correct.push(p);
			else virginica2.push(p);
		}
	});
	tfvis.render.scatterplot(
		document.getElementById('chart2'),
		{values:[correct,versicolor2,virginica2,setosa2], series:[' correct','_versicolor','_virginica','setosa']},
		{xLabel:'petallength', yLabel:'petalwidth', width:350, height:400, seriesColors:['gray','green','blue','red']}
	);
	document.getElementById('score').innerHTML
	= '<p>'+(correct.length/0.75).toFixed(1).toString()+'%(学習データ)</p>';
}
document.addEventListener('DOMContentLoaded', run);
</script>
</head>
<body>
<h3>ニューラルネットの学習結果</h3>
<div id="training"></div>
<h4>2次元(花びらの長さと幅)の分布</h4>
<div id="chart"></div>
<h4>判別を誤ったデータ</h4>
<div id="chart2"></div>
<p><small>※ 共通する種類を同じ色で表しています。</small></p>
<h4>正答率</h4>
<div id="score"></div>
<h4>学習後のネットの出力値</h4>
<div id="output"></div>
<h4>第1層のパラメータ</h4>
<div id="weight"></div>
<h4>2乗誤差の平均 (MSE)</h4>
<div id="loss"></div>
<p><small>※ 結果が表示されるまでしばらくお待ちください。</small></p>
</body>
</html>

ニューラルネットの学習結果

2次元(花びらの長さと幅)の分布

判別を誤ったデータ

※ 共通する種類を同じ色で表しています。

正答率

学習後のネットの出力値

第1層のパラメータ

2乗誤差の平均 (MSE)

※ 結果が表示されるまでしばらくお待ちください。

Lesson

  1. 出力層に非線形の活性化関数を設定してみましょう。
  2. 
    const layer=tf.layers.dense({inputShape:[4], units:3, activation:'sigmoid'});
    model.add(layer);
    

    ※ optimizer: tf.train.adam() に変更。反復回数は1000回

  3. 3層ネットワークで学習させてみましょう。
    
    const layer=tf.layers.dense({inputShape:[4], units:2, activation:'sigmoid'});
    model.add(layer);
    const layer2=tf.layers.dense({units:3, activation:'sigmoid'});
    model.add(layer2);
    

    ※ optimizer: tf.train.adam() に変更。反復回数は2000回

  4. 前問迄の各場合について、半分のデータで学習を行い、残り半分のデータの正答率を求めてみましょう。2つ目のグラフの入力を未知データに変更する。
    
    const xs = tf.tensor2d(x,[x.length,4]);
    const xs2 = tf.tensor2d(x2,[x2.length,4]);
    const ys = tf.tensor2d(y,[y.length,3]);
    	.
    	.
    	.
    const virginica2=[];
    const correct=[];
    model.predict(xs2).arraySync().map((yi, i) =>{
    	const p={x:x2[i][2], y:x2[i][3]};
    

【事後学習】本日学んだ内容を再確認しておきましょう。

This site is powered by Powered by MathJax