Tutorial วันนี้เรามาอธิบาย concept ของ Logistic Regression เบื้องต้น พร้อมโค้ดตัวอย่างใน R สำหรับสร้างและทดสอบโมเดล – Case Study ทำนายการเกิดมะเร็งเต้านม (Breast Cancer Dataset) Show When to use?เรานิยมใช้ Logistic Regression กับปัญหา Binary Classification i.e. ทำนาย target variable ที่มีสอง classes และใช้ค่า % accuracy สำหรับวัดผลโมเดลเบื้องต้น ด้านล่างเป็นตัวอย่าง use cases ในชีวิตจริง
Key Conceptโมเดลนี้ง่ายกว่าที่คิด !! เพราะ Logistic Regression จริงๆคือ Extended Version ของ Linear Regression รูปด้านล่างแอดลอง plot กราฟขึ้นมาจะเห็นว่าแกนตั้ง y มีได้สองค่าคือ {0, 1} ส่วนแกนนอนคือตัวแปร X1 สำหรับข้อมูลประเภทนี้ ถ้าใช้ Linear Regression ทั่วไป ถามว่าเทรนได้ไหม ก็ทำได้แต่ผลจะออกมาไม่ดี เหตุผลที่เราไม่ใช้ linear regression กับปัญหา {0, 1} แบบนี้คือ [su_highlight][1] linear regression เหมาะกับตัวแปร y แบบ continuous และ [2] ผลทำนายของสมการ linear regression เป็นแบบเส้นตรงมีโอกาสที่จะต่ำกว่าศูนย์หรือสูงกว่าหนึ่ง ซึ่งไม่ตอบโจทย์ binary classification ที่ค่า y ต้องอยู่ในช่วง [0,1] [/su_highlight] แล้วเราจะแก้ปัญหานี้ยังไงดี? นักคณิตศาสตร์เลยคิด Sigmoid Function ขึ้นมาเพื่อใช้ normalize ตัวเลขอะไรก็ได้ให้มีค่าอยู่ระหว่าง [0, 1] สำหรับแก้ปัญหา binary classification โดยเฉพาะ กราฟด้านล่างเราเปลี่ยนสมการเส้นตรงให้กลายเป็น S-Curve ที่ fit กับข้อมูล [0, 1] ได้ดีขึ้นมาก .. What a Cool Trick! ดาวน์โหลดไฟล์ตัวอย่างการเขียน Sigmoid Function ใน Excel ได้ที่นี่ How Sigmoid Works?Sigmoid สามารถเขียนเป็นสมการทางคณิตศาสตร์ได้ตามรูปด้านล่าง โดยที่ e คือ Exponential Function หรือฟังชั่น exp() ใน Excel/ R นักคณิตศาสตร์ใช้ฟังชั่นตระกูล exp ในการเปลี่ยนสมการ linear เป็น non-linear เส้นกราฟ Sigmoid ที่เราเห็นด้านบนเลยกลายเป็น s-curve สวยงาม แฮร่! ค่า Z ในสมการคือค่า weighted sum (เหมือนสมการ linear regression) แต่ Logistic Reression ใช้เทคนิคที่เรียกว่า Maximum Likelihood ในการคำนวณ weights (bo, b1, b2, …) แทนการใช้ Least Squares
ผลลัพธ์ที่ได้จาก sigmoid(Z) คือความน่าจะเป็นที่ y=1 เราสามารถกำหนด threshold สำหรับการทำนายของโมเดลได้ เช่น ถ้า sigmoid(Z) >= 0.5 ให้ทำนาย y=1 (positive) แต่ถ้าน้อยกว่า 0.5 ให้ทำนาย y=0 (negative) ทำไมมันง่ายอย่างงี้ๆๆ ตอนนี้เราสามารถสร้างโมเดลที่ fit กับข้อมูล 0, 1 ได้แล้ว 😛 Implementation in Rโอเคร! ตอนนี้เราเข้าใจ concept เบื้องต้นของ Logistic Regression แล้ว ถัดไปมาลองเขียนโค้ด R กันบ้างโดยโค้ดของเราจะแบ่งเป็น 5 ขั้นตอน Load data → Clean data → Split data → Train model → Test model [1] Load DataTutorial วันนี้เราใช้ข้อมูล Breast Cancer จาก package mlbench นักเรียนสามารถโหลดข้อมูลเข้าสู่ RStudio ด้วยโค้ดด้านล่าง ถ้าใครยังไม่เคยติดตั้ง package นี้ให้รันโค้ด Target variable ที่เราต้องการทำนายคือ “Class” {benign, malignant} โดย positive class ของโมเดลเราคือ malignant (เนื้อร้าย/ เป็นมะเร็ง) ซึ่งมีอยู่ประมาณ 35% ใน dataset นี้
[2] Clean Dataโหลดข้อมูลเสร็จแล้ว นักเรียนสามารถเรียกดู structure เบื้องต้นของ dataframe ด้วยฟังชั่น str() หรือ head() โค้ดด้านล่างเราใช้ฟังชั่น na.omit() เพื่อลบแถวที่มี missing values และลบคอลั่ม Id ด้วยการ assign NULL
[3] Split Dataแบ่งข้อมูลเป็น train 80% และ test 20% อย่าลืม set.seed(1) เพื่อให้ผล random id ของเราสามารถทำซ้ำได้
[4] Train Modelเทรน logistic regression ด้วยฟังชั่น glm() กำหนด family = “binomial” เพราะว่า target ของเรามี 2 classes {benign, malignant} เสร็จแล้วนำโมเดลที่ได้ไปทำนาย train_df และคำนวณค่า train accuracy ถ้ารันโค้ดด้านล่างเสร็จแล้วจะพบว่าค่า train accuracy เท่ากับ 100% ตอนนี้เราต้อง skeptical กับผลที่ได้แล้ว เพราะเป็นสัญญาณหนึ่งของปัญหา overfitting เด๋วเราจะลองนำโมเดลนี้ไปทำนาย test_df ในขั้นตอนต่อไป
Note – accuracy ไม่ใช่ metric เดียวที่เราใช้วัดผลโมเดล binary classification ปกติเราจะดูค่า precision, recall และ F1-score ด้วย อ่านเพิ่มเติมได้ในบทความแนะนำด้านล่าง 10 Metrics พื้นฐานสำหรับวัดผลโมเดล MLรู้จักกับ Accuracy, Precision, Recall, F1-Score สำหรับ Classification Problem [5] Test Modelโค้ดด้านล่างเขียนเหมือนตอนทำนาย train_df แค่เปลี่ยนชื่อข้อมูลทั้งหมดเป็น test_df เราจะได้[su_highlight]ค่า test accuracy อยู่ที่ 91.97% ซึ่งน้อยกว่า train accuracy ประมาณ 9% เราสามารถสรุปได้ทันทีว่าโมเดลที่เราสร้างขึ้นมามีปัญหา overfitting ดูง่ายๆจากผล test accuracy ที่มีค่าน้อยกว่า train accuracy มากๆ[/su_highlight]
แล้วเราจะลดปัญหา overfitting ได้ยังไง? คำตอบอยู่ข้างล่าง Read On! Regularizationเทคนิคสำคัญที่ ML practitioners ใช้ลดปัญหา overfitting เรียกว่า “Regularization” เป็นเทคนิคที่ทำให้โมเดลของเรา simple ขึ้น → generalize ได้ดีขึ้น สำหรับปัญหา regression เทคนิคนี้จะไปปรับค่า coefficients ของโมเดลให้มีขนาดเล็กลง หรือใน extreme cases คือปรับ coefficients เป็นศูนย์เลย Next Time – บทความต่อไปเราจะเขียนอธิบาย regularization ให้นักเรียนอ่านเต็มๆอีกที 🙂 Full R Codeโค้ดตั้งแต่ line 22 เป็นต้นไปใช้สำหรับสร้าง Regularized Logistic Regression (ridge, lasso, elastic net) ด้วย package glmnet และ caret (อีกชื่อหนึ่งของ Regularization คือ Penalized Regression)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters |