Binary logistic regression ค อ ม 2 ต วแปรตาม

Tutorial วันนี้เรามาอธิบาย concept ของ Logistic Regression เบื้องต้น พร้อมโค้ดตัวอย่างใน R สำหรับสร้างและทดสอบโมเดล – Case Study ทำนายการเกิดมะเร็งเต้านม (Breast Cancer Dataset)

When to use?

เรานิยมใช้ Logistic Regression กับปัญหา Binary Classification i.e. ทำนาย target variable ที่มีสอง classes และใช้ค่า % accuracy สำหรับวัดผลโมเดลเบื้องต้น ด้านล่างเป็นตัวอย่าง use cases ในชีวิตจริง

  • Churn prediction – ทำนายว่าลูกค้าจะเลิกใช้บริการหรือเปล่า (yes/ no)
  • Repeated purchase prediction – ทำนายว่าลูกค้าจะกลับมาซื้อสินค้าหรือเปล่า (yes/ no)
  • Disease detection – ทำนายว่าจะเป็นโรคหรือเปล่า (yes/ no)
  • Spam classification – ทำนายว่าอีเมล์เป็น spam หรือเปล่า (yes/ no)
  • (Marketing) Conversion prediction – ทำนายว่า user จะ take action หรือเปล่า (yes/ no)

Key Concept

โมเดลนี้ง่ายกว่าที่คิด !! เพราะ Logistic Regression จริงๆคือ Extended Version ของ Linear Regression รูปด้านล่างแอดลอง plot กราฟขึ้นมาจะเห็นว่าแกนตั้ง y มีได้สองค่าคือ {0, 1} ส่วนแกนนอนคือตัวแปร X1

สำหรับข้อมูลประเภทนี้ ถ้าใช้ Linear Regression ทั่วไป ถามว่าเทรนได้ไหม ก็ทำได้แต่ผลจะออกมาไม่ดี เหตุผลที่เราไม่ใช้ linear regression กับปัญหา {0, 1} แบบนี้คือ [su_highlight][1] linear regression เหมาะกับตัวแปร y แบบ continuous และ [2] ผลทำนายของสมการ linear regression เป็นแบบเส้นตรงมีโอกาสที่จะต่ำกว่าศูนย์หรือสูงกว่าหนึ่ง ซึ่งไม่ตอบโจทย์ binary classification ที่ค่า y ต้องอยู่ในช่วง [0,1] [/su_highlight]

Binary logistic regression ค อ ม 2 ต วแปรตาม

แล้วเราจะแก้ปัญหานี้ยังไงดี? นักคณิตศาสตร์เลยคิด Sigmoid Function ขึ้นมาเพื่อใช้ normalize ตัวเลขอะไรก็ได้ให้มีค่าอยู่ระหว่าง [0, 1] สำหรับแก้ปัญหา binary classification โดยเฉพาะ กราฟด้านล่างเราเปลี่ยนสมการเส้นตรงให้กลายเป็น S-Curve ที่ fit กับข้อมูล [0, 1] ได้ดีขึ้นมาก .. What a Cool Trick!

Binary logistic regression ค อ ม 2 ต วแปรตาม
ใช้ sigmoid function เพื่อ normalize ค่าให้อยู่ระหว่าง [0,1]

ดาวน์โหลดไฟล์ตัวอย่างการเขียน Sigmoid Function ใน Excel ได้ที่นี่

How Sigmoid Works?

Sigmoid สามารถเขียนเป็นสมการทางคณิตศาสตร์ได้ตามรูปด้านล่าง โดยที่ e คือ Exponential Function หรือฟังชั่น exp() ใน Excel/ R นักคณิตศาสตร์ใช้ฟังชั่นตระกูล exp ในการเปลี่ยนสมการ linear เป็น non-linear เส้นกราฟ Sigmoid ที่เราเห็นด้านบนเลยกลายเป็น s-curve สวยงาม แฮร่!

Binary logistic regression ค อ ม 2 ต วแปรตาม

ค่า Z ในสมการคือค่า weighted sum (เหมือนสมการ linear regression) แต่ Logistic Reression ใช้เทคนิคที่เรียกว่า Maximum Likelihood ในการคำนวณ weights (bo, b1, b2, …) แทนการใช้ Least Squares

## weighted sum (just like linear regression)
Z = b0 + b1x1 + b2x2 + b3x3 + b4x4 + ...
## apply sigmoid to Z value
probability_y = exp(Z) / (1+exp(Z))

ผลลัพธ์ที่ได้จาก sigmoid(Z) คือความน่าจะเป็นที่ y=1 เราสามารถกำหนด threshold สำหรับการทำนายของโมเดลได้ เช่น ถ้า sigmoid(Z) >= 0.5 ให้ทำนาย y=1 (positive) แต่ถ้าน้อยกว่า 0.5 ให้ทำนาย y=0 (negative)

ทำไมมันง่ายอย่างงี้ๆๆ ตอนนี้เราสามารถสร้างโมเดลที่ fit กับข้อมูล 0, 1 ได้แล้ว 😛

Implementation in R

โอเคร! ตอนนี้เราเข้าใจ concept เบื้องต้นของ Logistic Regression แล้ว ถัดไปมาลองเขียนโค้ด R กันบ้างโดยโค้ดของเราจะแบ่งเป็น 5 ขั้นตอน Load data → Clean data → Split data → Train model → Test model

[1] Load Data

Tutorial วันนี้เราใช้ข้อมูล Breast Cancer จาก package mlbench นักเรียนสามารถโหลดข้อมูลเข้าสู่ RStudio ด้วยโค้ดด้านล่าง ถ้าใครยังไม่เคยติดตั้ง package นี้ให้รันโค้ด install.packages("mlbench") ก่อน

Target variable ที่เราต้องการทำนายคือ “Class” {benign, malignant} โดย positive class ของโมเดลเราคือ malignant (เนื้อร้าย/ เป็นมะเร็ง) ซึ่งมีอยู่ประมาณ 35% ใน dataset นี้

## install.packages("mlbench")
library(mlbench)
data("BreastCancer")

[2] Clean Data

โหลดข้อมูลเสร็จแล้ว นักเรียนสามารถเรียกดู structure เบื้องต้นของ dataframe ด้วยฟังชั่น str() หรือ head() โค้ดด้านล่างเราใช้ฟังชั่น na.omit() เพื่อลบแถวที่มี missing values และลบคอลั่ม Id ด้วยการ assign NULL

## check if any missing values
mean(complete.cases(BreastCancer)) 
## remove rows with NA
df <- na.omit(BreastCancer) 
## remove column id
df$Id <- NULL 

[3] Split Data

แบ่งข้อมูลเป็น train 80% และ test 20% อย่าลืม set.seed(1) เพื่อให้ผล random id ของเราสามารถทำซ้ำได้

set.seed(1)
id <- sample(1:nrow(df), 0.8*nrow(df))
train_df <- df[id, ]
test_df <- df[-id, ]

[4] Train Model

เทรน logistic regression ด้วยฟังชั่น glm() กำหนด family = “binomial” เพราะว่า target ของเรามี 2 classes {benign, malignant} เสร็จแล้วนำโมเดลที่ได้ไปทำนาย train_df และคำนวณค่า train accuracy

ถ้ารันโค้ดด้านล่างเสร็จแล้วจะพบว่าค่า train accuracy เท่ากับ 100% ตอนนี้เราต้อง skeptical กับผลที่ได้แล้ว เพราะเป็นสัญญาณหนึ่งของปัญหา overfitting เด๋วเราจะลองนำโมเดลนี้ไปทำนาย test_df ในขั้นตอนต่อไป

## Train logistic regression
log_model <- glm(Class ~ ., data = train_df, family = "binomial")
## Predict and evaluate train dataset
p1 <- predict(log_model, type = "response")
p1 <- ifelse(p1 >= .5, T, F)
train_result <- table(p1, train_df$Class)
print(paste0("Train Accuracy: ", sum(diag(train_result)/ nrow(train_df))) )

Note – accuracy ไม่ใช่ metric เดียวที่เราใช้วัดผลโมเดล binary classification ปกติเราจะดูค่า precision, recall และ F1-score ด้วย อ่านเพิ่มเติมได้ในบทความแนะนำด้านล่าง

10 Metrics พื้นฐานสำหรับวัดผลโมเดล ML

รู้จักกับ Accuracy, Precision, Recall, F1-Score สำหรับ Classification Problem

[5] Test Model

โค้ดด้านล่างเขียนเหมือนตอนทำนาย train_df แค่เปลี่ยนชื่อข้อมูลทั้งหมดเป็น test_df เราจะได้[su_highlight]ค่า test accuracy อยู่ที่ 91.97% ซึ่งน้อยกว่า train accuracy ประมาณ 9% เราสามารถสรุปได้ทันทีว่าโมเดลที่เราสร้างขึ้นมามีปัญหา overfitting ดูง่ายๆจากผล test accuracy ที่มีค่าน้อยกว่า train accuracy มากๆ[/su_highlight]

## Predict and evaluate test dataset
p2 <- predict(log_model, newdata = test_df, type = "response")
p2 <- ifelse(p2 >= .5, T, F)
test_result <- table(p2, test_df$Class)
print(paste0("Test Accuracy: ", round(sum(diag(test_result)/ nrow(test_df)), 4)))

แล้วเราจะลดปัญหา overfitting ได้ยังไง? คำตอบอยู่ข้างล่าง Read On!

Regularization

เทคนิคสำคัญที่ ML practitioners ใช้ลดปัญหา overfitting เรียกว่า “Regularization” เป็นเทคนิคที่ทำให้โมเดลของเรา simple ขึ้น → generalize ได้ดีขึ้น สำหรับปัญหา regression เทคนิคนี้จะไปปรับค่า coefficients ของโมเดลให้มีขนาดเล็กลง หรือใน extreme cases คือปรับ coefficients เป็นศูนย์เลย

Next Time – บทความต่อไปเราจะเขียนอธิบาย regularization ให้นักเรียนอ่านเต็มๆอีกที 🙂

Full R Code

โค้ดตั้งแต่ line 22 เป็นต้นไปใช้สำหรับสร้าง Regularized Logistic Regression (ridge, lasso, elastic net) ด้วย package glmnet และ caret (อีกชื่อหนึ่งของ Regularization คือ Penalized Regression)

  • Line 31: Regularization จะมีสอง hyperparameters หลักที่เราจูนได้คือ alpha และ lambda
  • Line 32-33: เทรนโมเดลด้วย Grid Search 5-Fold CV
  • Line 48: Train accuracy 97.25%
  • Line 54: Test accuracy 97.08%

This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters