Update projects/conda.md
This commit is contained in:
@@ -1,3 +1,205 @@
|
||||
Absolutely! Let's simplify the project to create a Minimum Viable Product (MVP). The MVP will focus on the core functionality: fetching historical stock data, storing it in a SQLite database, and providing a simple web interface to display the actual and predicted stock prices.
|
||||
|
||||
### Simplified Project Structure
|
||||
|
||||
```
|
||||
stock_prediction/
|
||||
├── data/
|
||||
│ └── stock_data.db
|
||||
├── scripts/
|
||||
│ └── fetch_stock_data.py
|
||||
├── app/
|
||||
│ └── stock_prediction_app.py
|
||||
├── environment.yml
|
||||
└── requirements.txt
|
||||
```
|
||||
|
||||
### Simplified `fetch_stock_data.py`
|
||||
|
||||
This script fetches historical stock data and stores it in a SQLite database.
|
||||
|
||||
```python
|
||||
import argparse
|
||||
import yfinance as yf
|
||||
import pandas as pd
|
||||
import sqlite3
|
||||
import os
|
||||
|
||||
def fetch_stock_data(symbol, start_date, end_date, db_path):
    """Download price history for one symbol and store it in SQLite.

    The rows are written to a table named ``<symbol>_prices``; an existing
    table of that name is replaced. Failures are reported on stdout rather
    than raised, so the function always returns ``None``.

    Args:
        symbol: Ticker passed to ``yf.download`` (e.g. ``"^GSPC"``).
        start_date: Start of the range, passed through as ``start=``.
        end_date: End of the range, passed through as ``end=``.
        db_path: Path of the SQLite database file.
    """
    try:
        # Download before touching the database so that a failed or empty
        # download never replaces a previously stored table.
        data = yf.download(symbol, start=start_date, end=end_date)
        if data.empty:
            print(f"No data returned for {symbol}; database left unchanged.")
            return

        # NOTE(review): some yfinance releases return (field, ticker)
        # MultiIndex columns even for a single symbol. Keep only the field
        # level so the table gets plain "Close", "Open", ... column names.
        if isinstance(data.columns, pd.MultiIndex):
            data.columns = data.columns.get_level_values(0)

        table_name = f"{symbol}_prices"
        conn = sqlite3.connect(db_path)
        try:
            data.to_sql(name=table_name, con=conn, if_exists="replace")
        finally:
            # Close the connection even when the write fails.
            conn.close()
        print(f"Data for {symbol} stored in the database.")
    except sqlite3.Error as e:
        print(f"Error while storing data for {symbol}: {e}")
    except Exception as e:
        print(f"Error while fetching data for {symbol}: {e}")
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: read the symbol and date range, then store
    # the downloaded prices in ../data/stock_data.db next to this script.
    cli = argparse.ArgumentParser(
        description="Fetch historical stock data from Yahoo Finance."
    )
    cli.add_argument(
        "--symbol",
        type=str,
        default="^GSPC",
        help="Stock symbol (default: ^GSPC)",
    )
    cli.add_argument(
        "--start_date",
        type=str,
        default="2000-01-01",
        help="Start date (default: 2000-01-01)",
    )
    cli.add_argument(
        "--end_date",
        type=str,
        default="2023-05-31",
        help="End date (default: 2023-05-31)",
    )
    options = cli.parse_args()

    script_dir = os.path.dirname(__file__)
    database_file = os.path.join(script_dir, '../data/stock_data.db')
    fetch_stock_data(
        options.symbol,
        options.start_date,
        options.end_date,
        database_file,
    )
|
||||
```
|
||||
|
||||
### Simplified `stock_prediction_app.py`
|
||||
|
||||
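This Dash application reads the stored prices from the SQLite database and plots them. The prediction functionality is kept simple: each time a symbol is selected, a small LSTM is trained for a single epoch on the closing prices and used to predict one further closing price.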
|
||||
|
||||
```python
|
||||
import dash
|
||||
import dash_core_components as dcc
|
||||
import dash_html_components as html
|
||||
from dash.dependencies import Input, Output
|
||||
import pandas as pd
|
||||
import sqlite3
|
||||
import numpy as np
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
from tensorflow.keras.models import Sequential
|
||||
from tensorflow.keras.layers import LSTM, Dense
|
||||
import plotly.graph_objects as go
|
||||
import os
|
||||
|
||||
# Dash application object that the layout and callbacks attach to.
app = dash.Dash(__name__)

# Page layout: a heading, a dropdown offering three index symbols
# (S&P 500 selected initially), and the graph that shows the prices.
app.layout = html.Div(
    children=[
        html.H1(children="Stock Price Prediction"),
        html.Div(
            children=[
                html.Label(children="Select Stock Symbol"),
                dcc.Dropdown(
                    id="stock-dropdown",
                    options=[
                        {"label": "S&P 500", "value": "^GSPC"},
                        {"label": "Dow Jones", "value": "^DJI"},
                        {"label": "Nasdaq", "value": "^IXIC"},
                    ],
                    value="^GSPC",
                ),
            ]
        ),
        html.Div(children=[dcc.Graph(id="stock-graph")]),
    ]
)
|
||||
|
||||
def _quote_identifier(name):
    """Return *name* as a double-quoted SQLite identifier."""
    return '"' + str(name).replace('"', '""') + '"'


def _load_close_prices(stock_symbol, db_path):
    """Read the Date and Close columns of the ``<stock_symbol>_prices`` table."""
    # Symbols such as "^GSPC" are not valid bare identifiers, and the value
    # originates from the browser, so the table name must be quoted.
    table = _quote_identifier(f"{stock_symbol}_prices")
    conn = sqlite3.connect(db_path)
    try:
        return pd.read_sql_query(f"SELECT Date, Close FROM {table}", conn)
    finally:
        # Close the connection even when the query fails.
        conn.close()


@app.callback(Output("stock-graph", "figure"),
              [Input("stock-dropdown", "value")])
def update_graph(stock_symbol):
    """Plot the stored closing prices for a symbol plus one predicted close.

    An LSTM is trained from scratch (one epoch) on every call, using windows
    of the previous 60 scaled closes to predict the following one; the final
    window then yields a single prediction for the step after the last row.

    Args:
        stock_symbol: Value of the "stock-dropdown" component.

    Returns:
        A ``go.Figure`` with an "Actual Price" trace and, when there are more
        than 60 rows, a "Predicted Price" trace. An empty figure is returned
        when no symbol is given or when loading or modelling fails.
    """
    lookback = 60  # number of past closes fed to the model per sample

    if not stock_symbol:
        # Nothing selected (e.g. the dropdown was cleared): nothing to plot.
        return go.Figure()

    try:
        db_path = os.path.join(os.path.dirname(__file__), "../data/stock_data.db")
        data = _load_close_prices(stock_symbol, db_path)

        fig = go.Figure()
        fig.add_trace(go.Scatter(x=data["Date"], y=data["Close"], name="Actual Price"))
        fig.update_layout(title=f"{stock_symbol} Stock Price Prediction", xaxis_title="Date", yaxis_title="Price")

        if len(data) <= lookback:
            # Without at least one full window plus a target there is nothing
            # to train on; still show the prices that are available.
            print(f"Not enough data to predict {stock_symbol}: need more than {lookback} rows, found {len(data)}.")
            return fig

        scaler = MinMaxScaler(feature_range=(0, 1))
        scaled_data = scaler.fit_transform(data["Close"].values.reshape(-1, 1))

        # Each sample is the `lookback` values preceding index i; its target
        # is the value at index i.
        series = scaled_data[:, 0]
        X = np.array([series[i - lookback:i] for i in range(lookback, len(series))])
        y = series[lookback:]
        X = X.reshape(X.shape[0], lookback, 1)

        model = Sequential()
        model.add(LSTM(units=50, return_sequences=True, input_shape=(lookback, 1)))
        model.add(LSTM(units=50))
        model.add(Dense(1))
        model.compile(loss="mean_squared_error", optimizer="adam")
        model.fit(X, y, epochs=1, batch_size=32)

        # The most recent window predicts the step after the last stored row.
        X_test = scaled_data[-lookback:].reshape(1, lookback, 1)
        predicted_prices = scaler.inverse_transform(model.predict(X_test))

        # Place the prediction on the next business day instead of on top of
        # the last actual price (market holidays are not taken into account).
        next_date = pd.to_datetime(data["Date"].iloc[-1]) + pd.offsets.BDay(1)
        fig.add_trace(go.Scatter(x=[next_date], y=predicted_prices.flatten(), name="Predicted Price"))

        return fig
    except sqlite3.Error as e:
        print(f"Error while fetching data for {stock_symbol}: {e}")
        return go.Figure()
    except Exception as e:
        print(f"Error while updating graph for {stock_symbol}: {e}")
        return go.Figure()
|
||||
|
||||
if __name__ == "__main__":
    # Start the app's server with debug mode switched on when this file is
    # run as a script.
    app.run_server(debug=True)
|
||||
```
|
||||
|
||||
### `environment.yml`:
|
||||
|
||||
```yaml
|
||||
name: stock_prediction
|
||||
channels:
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.9
|
||||
- pip
|
||||
- pip:
|
||||
- -r requirements.txt
|
||||
```
|
||||
|
||||
### `requirements.txt`:
|
||||
|
||||
```
|
||||
dash
|
||||
pandas
|
||||
numpy
|
||||
yfinance
|
||||
scikit-learn
|
||||
tensorflow
|
||||
keras
|
||||
plotly
|
||||
```
|
||||
|
||||
### Instructions for Running the MVP
|
||||
|
||||
1. **Create the Conda Environment**:
|
||||
|
||||
```bash
|
||||
conda env create -f environment.yml
|
||||
```
|
||||
|
||||
2. **Activate the Conda Environment**:
|
||||
|
||||
```bash
|
||||
conda activate stock_prediction
|
||||
```
|
||||
|
||||
3. **Fetch Historical Stock Data**:
|
||||
|
||||
```bash
|
||||
python scripts/fetch_stock_data.py --symbol ^GSPC --start_date 2000-01-01 --end_date 2023-05-31
|
||||
```
|
||||
|
||||
4. **Run the Dash Application**:
|
||||
|
||||
```bash
|
||||
python app/stock_prediction_app.py
|
||||
```
|
||||
|
||||
5. **Open the Web Application**:
|
||||
|
||||
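Open your web browser and go to `http://localhost:8050` to access the application. Select a stock symbol from the dropdown menu to see the actual and predicted stock prices on the graph. Only symbols that have already been fetched into the database have data to show, so run step 3 again with `--symbol ^DJI` and `--symbol ^IXIC` before selecting Dow Jones or Nasdaq.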
|
||||
|
||||
This simplified version focuses on the essential features and ensures that the core functionalities work properly. Once this MVP is stable, you can gradually add more features and improvements.
|
||||
|
||||
---
|
||||
|
||||
Here's the complete project structure and code in one document:
|
||||
|
||||
Project Structure:
|
||||
|
||||
Reference in New Issue
Block a user