Cost function - SilverQ/dl_study GitHub Wiki

Regression Analysis(Wiki)

  • ์šฐ๋ฆฌ๋Š” ์ฃผ์‹ ์˜ˆ์ธก, ์ข…๋ฅ˜ ํŒ๋‹จ ๋“ฑ ์•Œ๋ ค์ง€์ง€ ์•Š์€ ๊ฐ์ฒด๋ฅผ ์‹๋ณ„ํ•˜๊ฑฐ๋‚˜ ๋ฏธ๋ž˜๋ฅผ ์˜ˆ์ธกํ•ด์•ผํ•˜๋Š” ์ผ์ด ๋งŽ๋‹ค.
  • ์šฐ๋ฆฌ๋Š” ์ด๋Ÿฐ ๊ฒฝ์šฐ ๊ด€์ธก ๊ฐ€๋Šฅํ•˜๊ฑฐ๋‚˜ ๊ณผ๊ฑฐ์— ์•Œ๋ ค์ง„ ์ •๋ณด๋ฅผ ํ† ๋Œ€๋กœ ํŒ๋‹จ์„ ํ•˜๊ฒŒ๋œ๋‹ค.
  • ๋•Œ๋กœ๋Š” "ํŠน์ • ์กฐ๊ฑด"์„ ์ฐพ์•„์•ผ ํ•˜๊ธฐ๋„ ํ•˜๋ฉฐ, ๋•Œ๋กœ๋Š” ์ˆ˜ํ•™์ ์ธ ๊ด€๊ณ„๊ฐ€ ์กด์žฌํ•  ์ˆ˜ ์žˆ๋‹ค.
  • ๋งŒ์•ฝ ์ˆ˜ํ•™์ ์ธ ๊ด€๊ณ„๋ฅผ ์ฐพ์•„๋‚ผ ์ˆ˜ ์žˆ๋‹ค๋ฉด?
  • ๊ฒฐ๊ณผ๋ฅผ ์•Œ์ง€ ๋ชปํ•˜๋Š” ์ž…๋ ฅ์— ๋Œ€ํ•ด์„œ๋„ ๊ฒฐ๊ณผ๋ฅผ ์˜ˆ์ธกํ•  ์ˆ˜ ์žˆ์„ ๊ฒƒ์ด๋‹ค.

์„ ํ˜• ํšŒ๊ท€

  • ์•ž์— ์„œ์ˆ ํ•œ ์ž…๋ ฅ-๊ฒฐ๊ณผ์˜ ์ˆ˜ํ•™์  ๊ด€๊ณ„๋ฅผ ํƒ์ƒ‰ํ•˜๊ธฐ ์œ„ํ•ด, ์šฐ๋ฆฌ๋Š” ๋ชจ์ข…์˜ ์„ ํ˜• ๋ฐฉ์ •์‹์„ ๊ฐ€์ •ํ•œ๋‹ค.
  • ๊ฐ€์„ค : hx = ax + b
  • ํ’€์–ด์„œ ์„ค๋ช…ํ•˜๋ฉด ๊ฒฐ๊ณผ์น˜ hx๋Š” ์ž…๋ ฅ x์— ์ƒ์ˆ˜ a์˜ ๊ธฐ์šธ๊ธฐ๋งŒํผ ๋น„๋ก€ํ•œ๋‹ค๋Š” ๊ฒƒ์ด๋‹ค. ์ž…๋ ฅ์ด ์—†๋‹ค๋ฉด ๊ฒฐ๊ณผ๋Š” ์ƒ์ˆ˜ b๊ฐ€ ๋‚˜์˜จ๋‹ค.
  • ์ง€๊ธˆ์€ 1์ฐจ ๋ฐฉ์ •์‹ ์ฆ‰ ์ง์„ ์œผ๋กœ ๊ฐ€์ •ํ•œ ๊ฒƒ์ด๊ณ , ์•Œ๋ ค์ง„ ๋ฐ์ดํ„ฐ์˜ ๋ถ„ํฌ๋ฅผ ๊ณ ๋ คํ•˜์—ฌ 2์ฐจ, 3์ฐจ ๋ฐฉ์ •์‹๋„ ์‚ฌ์šฉ์ด ๊ฐ€๋Šฅํ•˜๋‹ค.
  • ์šฐ๋ฆฌ๋Š” ์ด ์ง์„ ์˜ ๊ธฐ์šธ๊ธฐ์™€ y์ ˆํŽธ์„ ์•Œ๋ฉด ์ˆ˜์‹์„ ์•Œ์•„๋‚ผ ์ˆ˜ ์žˆ๋‹ค.
  • ์ด์ฒ˜๋Ÿผ ๊ฐ€์„ค์„ ์„ ํ˜• ๋ฐฉ์ •์‹์œผ๋กœ ์„ธ์šฐ๊ณ  ๋ฐ์ดํ„ฐ๋ฅผ ํ‘œํ˜„ํ•˜๋Š” ๋ชจ๋ธ์„ ํƒ์ƒ‰ํ•˜๋Š” ๊ฒƒ์„ ์„ ํ˜• ํšŒ๊ท€๋ผ ํ•œ๋‹ค.

์„ ํ˜• ํšŒ๊ท€์—์„œ์˜ ๋น„์šฉํ•จ์ˆ˜

  • ์ด์ œ ํ•„์š”ํ•œ ๊ฒƒ์€ ์œ„์˜ hx๊ฐ€ ์‹ค์ œ y๊ฐ’๊ณผ ์–ผ๋งˆ๋‚˜ ์ฐจ์ด๊ฐ€ ๋‚˜๋Š”์ง€ ์•Œ์•„๋‚ด๋Š” ๊ฒƒ์ด๊ณ ,
  • ๊ถ๊ทน์ ์œผ๋กœ๋Š” ๊ทธ ์ฐจ์ด๋ฅผ ์ค„์ด๊ธฐ ์œ„ํ•œ, ๊ฐ€์žฅ ์ค„์ผ ์ˆ˜ ์žˆ๋Š” ๊ธฐ์šธ๊ธฐ์™€ y์ ˆํŽธ์„ ์ฐพ๋Š” ๊ฒƒ์ด๋‹ค.
  • ๊ทธ๋ ‡๋‹ค๋ฉด ์•Œ๊ณ  ์žˆ๋Š” ๊ฐ’๊ณผ ์˜ˆ์ธกํ•œ ๊ฐ’์˜ ์ฐจ์ด๋Š” ๊ฐ€์„ค(hypothesis)์— ๊ด€ํ•œ ํ•จ์ˆ˜๋กœ ๋‚˜ํƒ€๋‚ผ ์ˆ˜ ์žˆ๋‹ค.
  • ๊ธฐ์šธ๊ธฐ์™€ y์ ˆํŽธ์˜ ๋ณ€ํ™”์— ๋”ฐ๋ผ ๊ฐ’์˜ ์ฐจ์ด๊ฐ€ ์ปค์ง€๊ฑฐ๋‚˜ ์ž‘์•„์ง€๋Š” ํ•จ์ˆ˜๋กœ ๋ณผ ์ˆ˜ ์žˆ์œผ๋ฉฐ, ์šฐ๋ฆฌ๋Š” ์ด๋ฅผ ๋น„์šฉ ํ•จ์ˆ˜๋ผ ๋ถ€๋ฅธ๋‹ค.
  • ์ˆ˜์‹

์˜ˆ์ œ

  • x์™€ y๊ฐ€ ๋™์ผํ•œ ์˜ˆ์ œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ธฐ์šธ๊ธฐ w์— ์˜ํ•œ cost ๋ณ€ํ™”๋ฅผ ๊ด€์ฐฐํ•œ๋‹ค.
  1. x = [1, 2, 3]
  2. y = [1, 2, 3]
def cost(x, y, w):
    c = 0
    for i in range(len(x)):
        hx = w * x[i]
        c += (hx - y[i]) ** 2
    return c / len(x)
x = [1, 2, 3]
y = [1, 2, 3]
for i in range(-30, 51):
    w = i / 10
    c = cost(x, y, w)
    plt.plot(w, c, 'ro')
  • ์‹คํ–‰๊ฒฐ๊ณผ
  1. cost ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•˜๊ณ  ๊ธฐ์šธ๊ธฐ๋ฅผ -30์—์„œ 50๊นŒ์ง€ ๋ณ€ํ™”์‹œ์ผœ๊ฐ€๋ฉฐ ๊ด€์ฐฐ
  2. ์ฃผ์–ด์ง„ x์™€ y์˜ ๊ด€๊ณ„๋ฅผ ๊ฐ€์žฅ ์ž˜ ๋‚˜ํƒ€๋‚ด๋Š” ๊ฒƒ์€ y=x ์ฆ‰ ๊ธฐ์šธ๊ธฐ=1์ธ ๊ฐ€์„คํ•จ์ˆ˜์ด๋‹ค.
  3. ๊ทธ ๊ฒฐ๊ณผ w=1์ธ ์ง€์ ์˜ cost๊ฐ€ ๊ฐ€์žฅ ๋‚ฎ์€ ๊ฒƒ์„ ์•Œ ์ˆ˜ ์žˆ๋‹ค.
  4. ๋‹ค์Œ์œผ๋กœ๋Š” ๊ฒฝ์‚ฌํ•˜๊ฐ•๋ฒ•์„ ์‚ฌ์šฉํ•˜์—ฌ cost๋ฅผ ์ตœ์†Œํ™”ํ•˜๋Š” w๊ฐ’์„ ์ฐพ๊ณ ์ž ํ•œ๋‹ค.
  5. cost