ChatGPTに将来の株価を予想してもらうコードを書いてもらった
ChatGPTに将来の株価を予想してもらうコードを書いてもらいました。私自身がpandas、prophet、各種API等を普段から使っていて、ある程度の予備知識をもとにしている質問なのでズルい感じもしますが(普通だと環境構築でそこそこ手間取ります)。それでも、ChatGPTへの質問一発で、必要最低限、動くコードが出てきました。素晴らしい。
ChatGPTへの指示文 :
東証銘柄の過去データをyahooから取得して、Prophetで予測するPythonプログラムを書いてください。
回答 :
以下のPythonコードは、Yahoo Finance APIを使用して、東証銘柄の過去データを取得し、Prophetを使用して時系列予測を行うサンプルコードです。このコードを実行するには、必要なPythonライブラリとして、pandas、yfinance、Prophetをインストールする必要があります。
import pandas as pd
import yfinance as yf
from fbprophet import Prophet
# 銘柄コードを指定
code = '9984.T'
# Yahoo Finance APIから株価データを取得
df = yf.download(code, start="2010-01-01", end="2022-03-24")
# Prophet用にデータを整形
df = pd.DataFrame({'ds': df.index, 'y': df['Adj Close']}).reset_index(drop=True)
# Prophetのモデルを作成
model = Prophet()
model.fit(df)
# 未来の日付を生成
future = model.make_future_dataframe(periods=365)
# 予測を実行
forecast = model.predict(future)
# 予測結果を出力
fig = model.plot(forecast)
上記のコードでは、銘柄コードを'9984.T'として指定していますが、予測したい銘柄コードに変更してください。また、startとendで取得するデータの期間を指定しています。ここでは2010年1月1日から2022年3月24日までのデータを取得していますが、必要に応じて変更してください。
また、Prophetのモデルを作成する際に、様々なパラメータを指定することができます。例えば、growthやseasonality_mode、changepoint_prior_scaleなどがあります。詳細については、Prophetの公式ドキュメントを参照してください。
実行結果 :
見事に予測グラフができました。
上記では、銘柄は「NEXT FUNDS 日経225連動型上場投信 (1321)」を指定して、2017-01-01から 2022-03-24の過去データを取得して、365日先までの予測をさせています。
過去10年までにすると、株価の上昇トレンドが強すぎて一本線みたいなグラフになったので、適当に5年間までのデータに絞っています。
改良例:
上記のコードだとグラフが画像として表示されて使い勝手が悪いので、plotlyを使う形に改良してもらいました。
改良例1
# 予測結果をplotlyで表示
fig = make_subplots(rows=2, cols=1, shared_xaxes=True)
# 実際の株価を表示
fig.add_trace(go.Scatter(x=df['ds'], y=df['y'], name='Actual'), row=1, col=1)
# 予測結果を表示
fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], name='Predicted'), row=2, col=1)
fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat_upper'], name='Upper Bound'), row=2, col=1)
fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat_lower'], name='Lower Bound'), row=2, col=1)
fig.update_layout(title='Stock Price Prediction for '+code, xaxis_title='Date', yaxis_title='Price')
fig.show()
改良例2
# グラフの描画
fig = go.Figure()
fig.add_trace(go.Scatter(x=df['ds'], y=df['y'], name='actual'))
fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], name='forecast'))
fig.update_layout(title=f'{code} 株価予測')
fig.show()