784 lines
24 KiB
Markdown
784 lines
24 KiB
Markdown
Absolutely! Let's simplify the project to create a Minimum Viable Product (MVP). The MVP will focus on the core functionality: fetching historical stock data, storing it in a SQLite database, and providing a simple web interface to display the actual and predicted stock prices.
|
|
|
|
### Simplified Project Structure
|
|
|
|
```
|
|
stock_prediction/
|
|
├── data/
|
|
│ └── stock_data.db
|
|
├── scripts/
|
|
│ └── fetch_stock_data.py
|
|
├── app/
|
|
│ └── stock_prediction_app.py
|
|
├── environment.yml
|
|
└── requirements.txt
|
|
```
|
|
|
|
### Simplified `fetch_stock_data.py`
|
|
|
|
This script fetches historical stock data and stores it in a SQLite database.
|
|
|
|
```python
|
|
import argparse
|
|
import yfinance as yf
|
|
import pandas as pd
|
|
import sqlite3
|
|
import os
|
|
|
|
def fetch_stock_data(symbol, start_date, end_date, db_path):
    """Download historical prices for ``symbol`` and store them in SQLite.

    The frame returned by ``yf.download`` is written to a table named
    ``{symbol}_prices``, replacing any existing table of that name.
    Failures are reported on stdout instead of being raised.

    Args:
        symbol: Ticker symbol passed to ``yf.download`` (e.g. ``"^GSPC"``).
        start_date: Passed through to ``yf.download`` as ``start``.
        end_date: Passed through to ``yf.download`` as ``end``.
        db_path: Path of the SQLite database file to write to.
    """
    conn = None
    try:
        data = yf.download(symbol, start=start_date, end=end_date)
        if data.empty:
            # Do not replace a previously stored table with an empty one.
            print(f"Error while fetching data for {symbol}: no rows returned.")
            return

        # NOTE(review): newer yfinance releases appear to return two-level
        # columns (field, ticker) even for one symbol; keep only the field
        # names so the table gets plain "Close", "Open", ... columns.
        if isinstance(data.columns, pd.MultiIndex):
            field_names = data.columns.get_level_values(0)
            if field_names.is_unique:
                data.columns = field_names

        # sqlite3 creates the database file but not missing parent folders.
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

        conn = sqlite3.connect(db_path)
        table_name = f"{symbol}_prices"
        data.to_sql(name=table_name, con=conn, if_exists="replace")
        print(f"Data for {symbol} stored in the database.")
    except sqlite3.Error as e:
        print(f"Error while storing data for {symbol}: {e}")
    except Exception as e:
        print(f"Error while fetching data for {symbol}: {e}")
    finally:
        # Release the database handle on both the success and error paths.
        if conn is not None:
            conn.close()
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: read the symbol and date range, then store
    # the prices in ../data/stock_data.db relative to this script.
    cli = argparse.ArgumentParser(description="Fetch historical stock data from Yahoo Finance.")
    for flag, fallback, help_text in (
        ("--symbol", "^GSPC", "Stock symbol (default: ^GSPC)"),
        ("--start_date", "2000-01-01", "Start date (default: 2000-01-01)"),
        ("--end_date", "2023-05-31", "End date (default: 2023-05-31)"),
    ):
        cli.add_argument(flag, type=str, default=fallback, help=help_text)
    args = cli.parse_args()

    symbol, start_date, end_date = args.symbol, args.start_date, args.end_date
    db_path = os.path.join(os.path.dirname(__file__), '../data/stock_data.db')
    fetch_stock_data(symbol, start_date, end_date, db_path)
|
|
```
|
|
|
|
### Simplified `stock_prediction_app.py`
|
|
|
|
This Dash application fetches data from the SQLite database and displays it. The prediction functionality is kept simple.
|
|
|
|
```python
|
|
import dash
|
|
import dash_core_components as dcc
|
|
import dash_html_components as html
|
|
from dash.dependencies import Input, Output
|
|
import pandas as pd
|
|
import sqlite3
|
|
import numpy as np
|
|
from sklearn.preprocessing import MinMaxScaler
|
|
from tensorflow.keras.models import Sequential
|
|
from tensorflow.keras.layers import LSTM, Dense
|
|
import plotly.graph_objects as go
|
|
import os
|
|
|
|
app = dash.Dash(__name__)

# Page layout: a heading, the symbol dropdown and the graph that the
# callback fills in.
app.layout = html.Div(
    children=[
        html.H1("Stock Price Prediction"),
        html.Div(
            children=[
                html.Label("Select Stock Symbol"),
                dcc.Dropdown(
                    id="stock-dropdown",
                    options=[
                        {"label": label, "value": ticker}
                        for label, ticker in (
                            ("S&P 500", "^GSPC"),
                            ("Dow Jones", "^DJI"),
                            ("Nasdaq", "^IXIC"),
                        )
                    ],
                    value="^GSPC",
                ),
            ]
        ),
        html.Div(children=[dcc.Graph(id="stock-graph")]),
    ]
)
|
|
|
|
@app.callback(Output("stock-graph", "figure"),
              [Input("stock-dropdown", "value")])
def update_graph(stock_symbol):
    """Build the price figure for the symbol selected in the dropdown.

    Reads the ``Date`` and ``Close`` columns of the ``{symbol}_prices``
    table, trains a small LSTM on sliding 60-value windows of the scaled
    closing prices and predicts one value from the most recent window.

    Args:
        stock_symbol: Current value of the ``stock-dropdown`` component.

    Returns:
        A Plotly figure with the actual prices and, when more than 60 rows
        exist, the predicted price. An empty figure is returned on error.
    """
    lookback = 60  # number of past closes fed to the model per sample

    if not stock_symbol:
        # The dropdown was cleared, so there is no table to query.
        return go.Figure()

    try:
        db_path = os.path.join(os.path.dirname(__file__), "../data/stock_data.db")
        # The symbols start with "^", which is not allowed in a bare SQL
        # identifier: double-quote the table name (doubling embedded quotes).
        table_name = f"{stock_symbol}_prices".replace('"', '""')
        conn = sqlite3.connect(db_path)
        try:
            data = pd.read_sql_query(f'SELECT Date, Close FROM "{table_name}"', conn)
        finally:
            conn.close()  # release the handle even when the query fails

        fig = go.Figure()
        fig.add_trace(go.Scatter(x=data["Date"], y=data["Close"], name="Actual Price"))
        fig.update_layout(title=f"{stock_symbol} Stock Price Prediction", xaxis_title="Date", yaxis_title="Price")

        if len(data) <= lookback:
            # Too little history to form even one training window.
            return fig

        scaler = MinMaxScaler(feature_range=(0, 1))
        scaled_data = scaler.fit_transform(data["Close"].values.reshape(-1, 1))

        # Each sample is `lookback` consecutive closes; the target is the next one.
        X = np.array([scaled_data[i - lookback:i, 0] for i in range(lookback, len(scaled_data))])
        y = scaled_data[lookback:, 0]
        X = np.reshape(X, (X.shape[0], X.shape[1], 1))

        model = Sequential()
        model.add(LSTM(units=50, return_sequences=True, input_shape=(X.shape[1], 1)))
        model.add(LSTM(units=50))
        model.add(Dense(1))
        model.compile(loss="mean_squared_error", optimizer="adam")
        model.fit(X, y, epochs=1, batch_size=32)

        # Predict one value from the most recent window and undo the scaling.
        X_test = np.reshape(scaled_data[-lookback:], (1, lookback, 1))
        predicted_prices = scaler.inverse_transform(model.predict(X_test))

        # The single predicted point is drawn at the last available date.
        fig.add_trace(go.Scatter(x=data["Date"][-len(predicted_prices):], y=predicted_prices.flatten(), name="Predicted Price"))
        return fig
    except sqlite3.Error as e:
        print(f"Error while fetching data for {stock_symbol}: {e}")
        return go.Figure()
    except Exception as e:
        print(f"Error while updating graph for {stock_symbol}: {e}")
        return go.Figure()
|
|
|
|
if __name__ == "__main__":
    # `run_server` was deprecated in favour of `run` and is rejected by
    # current Dash releases; `run` starts the same development server.
    app.run(debug=True)
|
|
```
|
|
|
|
### `environment.yml`:
|
|
|
|
```yaml
|
|
name: stock_prediction
|
|
channels:
|
|
- defaults
|
|
dependencies:
|
|
- python=3.9
|
|
- pip
|
|
- pip:
|
|
- -r requirements.txt
|
|
```
|
|
|
|
### `requirements.txt`:
|
|
|
|
```
|
|
dash
|
|
pandas
|
|
numpy
|
|
yfinance
|
|
scikit-learn
|
|
tensorflow
|
|
keras
|
|
plotly
|
|
```
|
|
|
|
### Instructions for Running the MVP
|
|
|
|
1. **Create the Conda Environment**:
|
|
|
|
```bash
|
|
conda env create -f environment.yml
|
|
```
|
|
|
|
2. **Activate the Conda Environment**:
|
|
|
|
```bash
|
|
conda activate stock_prediction
|
|
```
|
|
|
|
3. **Fetch Historical Stock Data**:
|
|
|
|
```bash
|
|
python scripts/fetch_stock_data.py --symbol ^GSPC --start_date 2000-01-01 --end_date 2023-05-31
|
|
```
|
|
|
|
4. **Run the Dash Application**:
|
|
|
|
```bash
|
|
python app/stock_prediction_app.py
|
|
```
|
|
|
|
5. **Open the Web Application**:
|
|
|
|
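Open your web browser and go to `http://localhost:8050` to access the application. Select a stock symbol from the dropdown menu to see the actual and predicted stock prices on the graph. Only symbols that were fetched in step 3 have a table in the database, so run the fetch script for `^DJI` and `^IXIC` as well before selecting them; otherwise the graph stays empty.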
|
|
|
|
This simplified version focuses on the essential features and ensures that the core functionalities work properly. Once this MVP is stable, you can gradually add more features and improvements.
|
|
|
|
---
|
|
|
|
Here's the complete project structure and code in one document:
|
|
|
|
Project Structure:
|
|
```
|
|
stock_prediction/
|
|
├── data/
|
|
│ └── stock_data.db
|
|
├── scripts/
|
|
│ ├── fetch_stock_data.py
|
|
│ └── verify_data.py
|
|
├── app/
|
|
│ └── stock_prediction_app.py
|
|
├── environment.yml
|
|
└── requirements.txt
|
|
```
|
|
|
|
1. `stock_prediction/data/stock_data.db`:
|
|
- This is the SQLite database file that will store the historical stock data.
|
|
|
|
2. `stock_prediction/scripts/fetch_stock_data.py`:
|
|
```python
|
|
import argparse
import os
import sqlite3

import pandas as pd
import yfinance as yf

# Command-line options: which symbol to download and for which date range.
parser = argparse.ArgumentParser(description="Fetch historical stock data from Yahoo Finance.")
parser.add_argument("--symbol", type=str, default="^GSPC", help="Stock symbol (default: ^GSPC)")
parser.add_argument("--start_date", type=str, default="2000-01-01", help="Start date (default: 2000-01-01)")
parser.add_argument("--end_date", type=str, default="2023-05-31", help="End date (default: 2023-05-31)")
args = parser.parse_args()

symbol = args.symbol
start_date = args.start_date
end_date = args.end_date

# Resolve the database relative to this script rather than the current
# working directory, so `python scripts/fetch_stock_data.py` run from the
# project root writes to stock_prediction/data/stock_data.db.
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "data")
os.makedirs(data_dir, exist_ok=True)  # sqlite3 does not create folders
db_path = os.path.join(data_dir, "stock_data.db")

data = yf.download(symbol, start=start_date, end=end_date)

conn = sqlite3.connect(db_path)
try:
    # One table per symbol; an existing table of the same name is replaced.
    table_name = f"{symbol}_prices"
    data.to_sql(name=table_name, con=conn, if_exists="replace")
    print(f"Data for {symbol} stored in the database.")
finally:
    # Close the connection even if writing the table fails.
    conn.close()
|
|
```
|
|
|
|
3. `stock_prediction/scripts/verify_data.py`:
|
|
```python
|
|
import os
import sqlite3

import pandas as pd

# Resolve the database relative to this script so the check works no matter
# which directory it is started from.
db_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "data", "stock_data.db")

conn = sqlite3.connect(db_path)
try:
    # List every table stored in the database.
    tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table'", conn)
    print("Tables in the database:")
    print(tables)

    table_name = "^GSPC_prices"
    # "^" is not valid in a bare SQL identifier, so the name is double-quoted.
    data = pd.read_sql_query(f'SELECT * FROM "{table_name}"', conn)
    print(f"\nData from the {table_name} table:")
    print(data.head())
finally:
    # Close the connection even if one of the queries fails.
    conn.close()
|
|
```
|
|
|
|
4. `stock_prediction/app/stock_prediction_app.py`:
|
|
```python
|
|
import dash
|
|
import dash_core_components as dcc
|
|
import dash_html_components as html
|
|
from dash.dependencies import Input, Output
|
|
import pandas as pd
|
|
import sqlite3
|
|
import numpy as np
|
|
from sklearn.preprocessing import MinMaxScaler
|
|
from tensorflow.keras.models import Sequential
|
|
from tensorflow.keras.layers import LSTM, Dense
|
|
import plotly.graph_objects as go
|
|
|
|
app = dash.Dash(__name__)

# Layout: title, symbol selector and the graph updated by the callback.
app.layout = html.Div(
    children=[
        html.H1("Stock Price Prediction"),
        html.Div(
            children=[
                html.Label("Select Stock Symbol"),
                dcc.Dropdown(
                    id="stock-dropdown",
                    options=[
                        {"label": label, "value": ticker}
                        for label, ticker in (
                            ("S&P 500", "^GSPC"),
                            ("Dow Jones", "^DJI"),
                            ("Nasdaq", "^IXIC"),
                        )
                    ],
                    value="^GSPC",
                ),
            ]
        ),
        html.Div(children=[dcc.Graph(id="stock-graph")]),
    ]
)
|
|
|
|
@app.callback(Output("stock-graph", "figure"),
              [Input("stock-dropdown", "value")])
def update_graph(stock_symbol):
    """Build the price figure for the symbol selected in the dropdown.

    Reads ``Date`` and ``Close`` from the ``{symbol}_prices`` table, trains
    an LSTM on sliding 60-value windows of the scaled closing prices and
    predicts one value from the most recent window.

    Args:
        stock_symbol: Current value of the ``stock-dropdown`` component.

    Returns:
        A Plotly figure with the actual prices and, when more than 60 rows
        exist, the predicted price.
    """
    import os  # local import: used only to locate the database file

    lookback = 60  # number of past closes fed to the model per sample

    # Resolve the database relative to this file so the app also works when
    # started from the project root (`python app/stock_prediction_app.py`).
    db_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "data", "stock_data.db")
    # The symbols start with "^", which is not allowed in a bare SQL
    # identifier: double-quote the table name (doubling embedded quotes).
    table_name = f"{stock_symbol}_prices".replace('"', '""')
    conn = sqlite3.connect(db_path)
    try:
        data = pd.read_sql_query(f'SELECT Date, Close FROM "{table_name}"', conn)
    finally:
        conn.close()  # release the handle even when the query fails

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=data["Date"], y=data["Close"], name="Actual Price"))
    fig.update_layout(title=f"{stock_symbol} Stock Price Prediction", xaxis_title="Date", yaxis_title="Price")

    if len(data) <= lookback:
        # Too little history to form even one training window.
        return fig

    scaler = MinMaxScaler(feature_range=(0, 1))
    scaled_data = scaler.fit_transform(data["Close"].values.reshape(-1, 1))

    # Each sample is `lookback` consecutive closes; the target is the next one.
    X = np.array([scaled_data[i - lookback:i, 0] for i in range(lookback, len(scaled_data))])
    y = scaled_data[lookback:, 0]
    X = np.reshape(X, (X.shape[0], X.shape[1], 1))

    model = Sequential()
    model.add(LSTM(units=50, return_sequences=True, input_shape=(X.shape[1], 1)))
    model.add(LSTM(units=50))
    model.add(Dense(1))
    model.compile(loss="mean_squared_error", optimizer="adam")
    model.fit(X, y, epochs=10, batch_size=32)

    # The last `lookback` values form exactly one input window. The previous
    # loop over range(lookback, len(last_data)) was empty and produced no
    # window at all, which made the reshape fail.
    X_test = np.reshape(scaled_data[-lookback:], (1, lookback, 1))
    predicted_prices = model.predict(X_test)
    predicted_prices = scaler.inverse_transform(predicted_prices)

    # The single predicted point is drawn at the last available date.
    fig.add_trace(go.Scatter(x=data["Date"][-len(predicted_prices):], y=predicted_prices.flatten(), name="Predicted Price"))
    return fig
|
|
|
|
if __name__ == "__main__":
    # `run_server` was deprecated in favour of `run` and is rejected by
    # current Dash releases; `run` starts the same development server.
    app.run(debug=True)
|
|
```
|
|
|
|
5. `stock_prediction/environment.yml`:
|
|
```yaml
|
|
name: stock_prediction
|
|
channels:
|
|
- defaults
|
|
dependencies:
|
|
- python=3.9
|
|
- pip
|
|
- pip:
|
|
- -r requirements.txt
|
|
```
|
|
|
|
6. `stock_prediction/requirements.txt`:
|
|
```
|
|
dash
|
|
pandas
|
|
numpy
|
|
yfinance
|
|
scikit-learn
|
|
tensorflow
|
|
keras
|
|
plotly
|
|
```
|
|
|
|
To run the project:
|
|
|
|
1. Create a conda environment using the `environment.yml` file:
|
|
```
|
|
conda env create -f environment.yml
|
|
```
|
|
|
|
2. Activate the conda environment:
|
|
```
|
|
conda activate stock_prediction
|
|
```
|
|
|
|
3. Fetch historical stock data:
|
|
```
|
|
python scripts/fetch_stock_data.py --symbol ^GSPC --start_date 2000-01-01 --end_date 2023-05-31
|
|
```
|
|
You can change the stock symbol and date range as needed.
|
|
|
|
4. Verify the stored data (optional):
|
|
```
|
|
python scripts/verify_data.py
|
|
```
|
|
|
|
5. Run the Dash application:
|
|
```
|
|
python app/stock_prediction_app.py
|
|
```
|
|
|
|
6. Open a web browser and go to `http://localhost:8050` to access the application.
|
|
|
|
7. Select a stock symbol from the dropdown menu to see the actual and predicted stock prices on the graph.
|
|
|
|
This document provides the complete project structure, code files, and instructions for running the stock prediction project using conda for environment management and pip for package management.
|
|
|
|
---
|
|
|
|
Great! Let's update the project to use conda for creating the Python environment and pip for package management. Here's the updated project structure and instructions:
|
|
|
|
Project Structure:
|
|
```
|
|
stock_prediction/
|
|
├── data/
|
|
│ └── stock_data.db
|
|
├── scripts/
|
|
│ ├── fetch_stock_data.py
|
|
│ └── verify_data.py
|
|
├── app/
|
|
│ └── stock_prediction_app.py
|
|
├── environment.yml
|
|
└── requirements.txt
|
|
```
|
|
|
|
1. `stock_prediction/data/stock_data.db`:
|
|
- This is the SQLite database file that will store the historical stock data.
|
|
|
|
2. `stock_prediction/scripts/fetch_stock_data.py`:
|
|
- The content of this file remains the same as in the previous response.
|
|
|
|
3. `stock_prediction/scripts/verify_data.py`:
|
|
- The content of this file remains the same as in the previous response.
|
|
|
|
4. `stock_prediction/app/stock_prediction_app.py`:
|
|
- The content of this file remains the same as in the previous response.
|
|
|
|
5. `stock_prediction/environment.yml`:
|
|
```yaml
|
|
name: stock_prediction
|
|
channels:
|
|
- defaults
|
|
dependencies:
|
|
- python=3.9
|
|
- pip
|
|
- pip:
|
|
- -r requirements.txt
|
|
```
|
|
|
|
6. `stock_prediction/requirements.txt`:
|
|
```
|
|
dash
|
|
pandas
|
|
numpy
|
|
yfinance
|
|
scikit-learn
|
|
tensorflow
|
|
keras
|
|
plotly
|
|
```
|
|
|
|
To run the project:
|
|
|
|
1. Create a conda environment using the `environment.yml` file:
|
|
```
|
|
conda env create -f environment.yml
|
|
```
|
|
|
|
2. Activate the conda environment:
|
|
```
|
|
conda activate stock_prediction
|
|
```
|
|
|
|
3. Fetch historical stock data:
|
|
```
|
|
python scripts/fetch_stock_data.py --symbol ^GSPC --start_date 2000-01-01 --end_date 2023-05-31
|
|
```
|
|
You can change the stock symbol and date range as needed.
|
|
|
|
4. Verify the stored data (optional):
|
|
```
|
|
python scripts/verify_data.py
|
|
```
|
|
|
|
5. Run the Dash application:
|
|
```
|
|
python app/stock_prediction_app.py
|
|
```
|
|
|
|
6. Open a web browser and go to `http://localhost:8050` to access the application.
|
|
|
|
7. Select a stock symbol from the dropdown menu to see the actual and predicted stock prices on the graph.
|
|
|
|
The `environment.yml` file defines the conda environment named `stock_prediction` and specifies the Python version as 3.9. It also includes the `pip` package manager and instructs conda to install the packages listed in the `requirements.txt` file using pip.
|
|
|
|
The `requirements.txt` file remains the same and lists the required Python packages for the project.
|
|
|
|
By using conda for environment management and pip for package management, you can ensure a consistent and reproducible environment for your project.
|
|
|
|
Remember to activate the conda environment before running the scripts or the Dash application.
|
|
|
|
---
|
|
|
|
Project Structure:
|
|
```
|
|
stock_prediction/
|
|
├── data/
|
|
│ └── stock_data.db
|
|
├── scripts/
|
|
│ ├── fetch_stock_data.py
|
|
│ └── verify_data.py
|
|
├── app/
|
|
│ └── stock_prediction_app.py
|
|
└── requirements.txt
|
|
```
|
|
|
|
1. `stock_prediction/data/stock_data.db`:
|
|
- This is the SQLite database file that will store the historical stock data.
|
|
|
|
2. `stock_prediction/scripts/fetch_stock_data.py`:
|
|
```python
|
|
import argparse
import os
import sqlite3

import pandas as pd
import yfinance as yf

# Command-line options: which symbol to download and for which date range.
parser = argparse.ArgumentParser(description="Fetch historical stock data from Yahoo Finance.")
parser.add_argument("--symbol", type=str, default="^GSPC", help="Stock symbol (default: ^GSPC)")
parser.add_argument("--start_date", type=str, default="2000-01-01", help="Start date (default: 2000-01-01)")
parser.add_argument("--end_date", type=str, default="2023-05-31", help="End date (default: 2023-05-31)")
args = parser.parse_args()

symbol = args.symbol
start_date = args.start_date
end_date = args.end_date

# Resolve the database relative to this script rather than the current
# working directory, so the script can be started from any folder.
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "data")
os.makedirs(data_dir, exist_ok=True)  # sqlite3 does not create folders
db_path = os.path.join(data_dir, "stock_data.db")

data = yf.download(symbol, start=start_date, end=end_date)

conn = sqlite3.connect(db_path)
try:
    # One table per symbol; an existing table of the same name is replaced.
    table_name = f"{symbol}_prices"
    data.to_sql(name=table_name, con=conn, if_exists="replace")
    print(f"Data for {symbol} stored in the database.")
finally:
    # Close the connection even if writing the table fails.
    conn.close()
|
|
```
|
|
|
|
3. `stock_prediction/scripts/verify_data.py`:
|
|
```python
|
|
import os
import sqlite3

import pandas as pd

# Resolve the database relative to this script so the check works no matter
# which directory it is started from.
db_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "data", "stock_data.db")

conn = sqlite3.connect(db_path)
try:
    # List every table stored in the database.
    tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table'", conn)
    print("Tables in the database:")
    print(tables)

    table_name = "^GSPC_prices"
    # "^" is not valid in a bare SQL identifier, so the name is double-quoted.
    data = pd.read_sql_query(f'SELECT * FROM "{table_name}"', conn)
    print(f"\nData from the {table_name} table:")
    print(data.head())
finally:
    # Close the connection even if one of the queries fails.
    conn.close()
|
|
```
|
|
|
|
4. `stock_prediction/app/stock_prediction_app.py`:
|
|
```python
|
|
import dash
|
|
import dash_core_components as dcc
|
|
import dash_html_components as html
|
|
from dash.dependencies import Input, Output
|
|
import pandas as pd
|
|
import sqlite3
|
|
import numpy as np
|
|
from sklearn.preprocessing import MinMaxScaler
|
|
from tensorflow.keras.models import Sequential
|
|
from tensorflow.keras.layers import LSTM, Dense
|
|
import plotly.graph_objects as go
|
|
|
|
app = dash.Dash(__name__)

# Layout: title, symbol selector and the graph updated by the callback.
app.layout = html.Div(
    children=[
        html.H1("Stock Price Prediction"),
        html.Div(
            children=[
                html.Label("Select Stock Symbol"),
                dcc.Dropdown(
                    id="stock-dropdown",
                    options=[
                        {"label": label, "value": ticker}
                        for label, ticker in (
                            ("S&P 500", "^GSPC"),
                            ("Dow Jones", "^DJI"),
                            ("Nasdaq", "^IXIC"),
                        )
                    ],
                    value="^GSPC",
                ),
            ]
        ),
        html.Div(children=[dcc.Graph(id="stock-graph")]),
    ]
)
|
|
|
|
@app.callback(Output("stock-graph", "figure"),
              [Input("stock-dropdown", "value")])
def update_graph(stock_symbol):
    """Build the price figure for the symbol selected in the dropdown.

    Reads ``Date`` and ``Close`` from the ``{symbol}_prices`` table, trains
    an LSTM on sliding 60-value windows of the scaled closing prices and
    predicts one value from the most recent window.

    Args:
        stock_symbol: Current value of the ``stock-dropdown`` component.

    Returns:
        A Plotly figure with the actual prices and, when more than 60 rows
        exist, the predicted price.
    """
    import os  # local import: used only to locate the database file

    lookback = 60  # number of past closes fed to the model per sample

    # Resolve the database relative to this file so the app also works when
    # started from the project root (`python app/stock_prediction_app.py`).
    db_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "data", "stock_data.db")
    # The symbols start with "^", which is not allowed in a bare SQL
    # identifier: double-quote the table name (doubling embedded quotes).
    table_name = f"{stock_symbol}_prices".replace('"', '""')
    conn = sqlite3.connect(db_path)
    try:
        data = pd.read_sql_query(f'SELECT Date, Close FROM "{table_name}"', conn)
    finally:
        conn.close()  # release the handle even when the query fails

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=data["Date"], y=data["Close"], name="Actual Price"))
    fig.update_layout(title=f"{stock_symbol} Stock Price Prediction", xaxis_title="Date", yaxis_title="Price")

    if len(data) <= lookback:
        # Too little history to form even one training window.
        return fig

    scaler = MinMaxScaler(feature_range=(0, 1))
    scaled_data = scaler.fit_transform(data["Close"].values.reshape(-1, 1))

    # Each sample is `lookback` consecutive closes; the target is the next one.
    X = np.array([scaled_data[i - lookback:i, 0] for i in range(lookback, len(scaled_data))])
    y = scaled_data[lookback:, 0]
    X = np.reshape(X, (X.shape[0], X.shape[1], 1))

    model = Sequential()
    model.add(LSTM(units=50, return_sequences=True, input_shape=(X.shape[1], 1)))
    model.add(LSTM(units=50))
    model.add(Dense(1))
    model.compile(loss="mean_squared_error", optimizer="adam")
    model.fit(X, y, epochs=10, batch_size=32)

    # The last `lookback` values form exactly one input window. The previous
    # loop over range(lookback, len(last_data)) was empty and produced no
    # window at all, which made the reshape fail.
    X_test = np.reshape(scaled_data[-lookback:], (1, lookback, 1))
    predicted_prices = model.predict(X_test)
    predicted_prices = scaler.inverse_transform(predicted_prices)

    # The single predicted point is drawn at the last available date.
    fig.add_trace(go.Scatter(x=data["Date"][-len(predicted_prices):], y=predicted_prices.flatten(), name="Predicted Price"))
    return fig
|
|
|
|
if __name__ == "__main__":
    # `run_server` was deprecated in favour of `run` and is rejected by
    # current Dash releases; `run` starts the same development server.
    app.run(debug=True)
|
|
```
|
|
|
|
5. `stock_prediction/requirements.txt`:
|
|
```
|
|
dash
|
|
pandas
|
|
numpy
|
|
yfinance
|
|
scikit-learn
|
|
tensorflow
|
|
keras
|
|
plotly
|
|
```
|
|
|
|
Certainly! Let's combine Dash, deep learning, and price prediction capabilities into a working project. We'll create a web application using Dash that allows users to select a stock symbol, trains an LSTM model on the historical data, and displays the predicted stock prices.
|
|
|
|
Here's a step-by-step guide to create the project:
|
|
|
|
Step 1: Set up the environment
|
|
1. Make sure you have Miniconda installed and activated.
|
|
|
|
2. Create a new conda environment for this project:
|
|
```
|
|
conda create --name stock_prediction_app python=3.9
|
|
conda activate stock_prediction_app
|
|
```
|
|
|
|
3. Install the required libraries:
|
|
```
|
|
conda install pandas numpy yfinance scikit-learn tensorflow keras
|
|
conda install -c conda-forge sqlite dash
|
|
```
|
|
|
|
Step 2: Fetch historical stock data
|
|
1. Use the `fetch_stock_data.py` script from the previous example to fetch historical stock data and store it in the SQLite database.
|
|
|
|
Step 3: Create the Dash application
|
|
1. Create a new Python script, e.g., `stock_prediction_app.py`, and add the following code:
|
|
```python
|
|
import sqlite3

import dash
import dash_core_components as dcc
import dash_html_components as html
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from dash.dependencies import Input, Output
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.models import Sequential
|
|
|
|
app = dash.Dash(__name__)

# Define the layout of the application: a title, the symbol selector and
# the graph updated by the callback.
app.layout = html.Div(
    children=[
        html.H1("Stock Price Prediction"),
        html.Div(
            children=[
                html.Label("Select Stock Symbol"),
                dcc.Dropdown(
                    id="stock-dropdown",
                    options=[
                        {"label": label, "value": ticker}
                        for label, ticker in (
                            ("S&P 500", "^GSPC"),
                            ("Dow Jones", "^DJI"),
                            ("Nasdaq", "^IXIC"),
                        )
                    ],
                    value="^GSPC",
                ),
            ]
        ),
        html.Div(children=[dcc.Graph(id="stock-graph")]),
    ]
)
|
|
|
|
# Callback to update the graph based on the selected stock symbol
|
|
@app.callback(Output("stock-graph", "figure"),
              [Input("stock-dropdown", "value")])
def update_graph(stock_symbol):
    """Build the price figure for the symbol selected in the dropdown.

    Reads ``Date`` and ``Close`` from the ``{symbol}_prices`` table, trains
    an LSTM on sliding 60-value windows of the scaled closing prices and
    predicts one value from the most recent window.

    Args:
        stock_symbol: Current value of the ``stock-dropdown`` component.

    Returns:
        A Plotly figure with the actual prices and, when more than 60 rows
        exist, the predicted price.
    """
    lookback = 60  # number of past closes fed to the model per sample

    # Load data from SQLite database. The symbols start with "^", which is
    # not allowed in a bare SQL identifier, so the table name is
    # double-quoted (embedded quotes doubled).
    table_name = f"{stock_symbol}_prices".replace('"', '""')
    conn = sqlite3.connect("stock_data.db")
    try:
        data = pd.read_sql_query(f'SELECT Date, Close FROM "{table_name}"', conn)
    finally:
        conn.close()  # release the handle even when the query fails

    # Create the graph figure with the actual prices first.
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=data["Date"], y=data["Close"], name="Actual Price"))
    fig.update_layout(title=f"{stock_symbol} Stock Price Prediction", xaxis_title="Date", yaxis_title="Price")

    if len(data) <= lookback:
        # Too little history to form even one training window.
        return fig

    # Prepare the data for training
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaled_data = scaler.fit_transform(data["Close"].values.reshape(-1, 1))

    # Create training data: each sample is `lookback` consecutive closes and
    # the target is the close that follows them.
    X = np.array([scaled_data[i - lookback:i, 0] for i in range(lookback, len(scaled_data))])
    y = scaled_data[lookback:, 0]
    X = np.reshape(X, (X.shape[0], X.shape[1], 1))

    # Build and train the LSTM model
    model = Sequential()
    model.add(LSTM(units=50, return_sequences=True, input_shape=(X.shape[1], 1)))
    model.add(LSTM(units=50))
    model.add(Dense(1))
    model.compile(loss="mean_squared_error", optimizer="adam")
    model.fit(X, y, epochs=10, batch_size=32)

    # Make predictions. The last `lookback` values form exactly one input
    # window; the previous loop over range(lookback, len(last_data)) was
    # empty and produced no window at all, which made the reshape fail.
    X_test = np.reshape(scaled_data[-lookback:], (1, lookback, 1))
    predicted_prices = model.predict(X_test)
    predicted_prices = scaler.inverse_transform(predicted_prices)

    # The single predicted point is drawn at the last available date.
    fig.add_trace(go.Scatter(x=data["Date"][-len(predicted_prices):], y=predicted_prices.flatten(), name="Predicted Price"))
    return fig
|
|
|
|
if __name__ == "__main__":
    # `run_server` was deprecated in favour of `run` and is rejected by
    # current Dash releases; `run` starts the same development server.
    app.run(debug=True)
|
|
```
|
|
|
|
This script creates a Dash application with a dropdown menu to select a stock symbol and a graph to display the actual and predicted stock prices.
|
|
|
|
The `app.layout` defines the structure of the application, including the dropdown menu and the graph.
|
|
|
|
The `@app.callback` decorator is used to update the graph based on the selected stock symbol. It loads the data from the SQLite database, prepares the data for training, builds and trains the LSTM model, makes predictions, and creates the graph figure using the Plotly library.
|
|
|
|
2. Run the script to start the Dash application:
|
|
```
|
|
python stock_prediction_app.py
|
|
```
|
|
|
|
The application will start running on a local server, and you can access it in your web browser at `http://localhost:8050`.
|
|
|
|
3. Select a stock symbol from the dropdown menu to see the actual and predicted stock prices on the graph.
|
|
|
|
That's it! You now have a working project that combines Dash, deep learning, and price prediction capabilities. Users can select a stock symbol, and the application will train an LSTM model on the historical data and display the predicted stock prices alongside the actual prices.
|
|
|
|
This project serves as a starting point, and you can further enhance it by adding more features, improving the model architecture, and incorporating additional data preprocessing techniques.
|
|
|
|
Remember to handle any errors or exceptions that may occur during data loading, model training, or prediction.
|
|
|
|
Feel free to explore and customize the application based on your specific requirements and preferences! |