Absolutely! Let's simplify the project to create a Minimum Viable Product (MVP). The MVP will focus on the core functionality: fetching historical stock data, storing it in a SQLite database, and providing a simple web interface to display the actual and predicted stock prices.

### Simplified Project Structure

```
stock_prediction/
├── data/
│   └── stock_data.db
├── scripts/
│   └── fetch_stock_data.py
├── app/
│   └── stock_prediction_app.py
├── environment.yml
└── requirements.txt
```

### Simplified `fetch_stock_data.py`

This script fetches historical stock data from Yahoo Finance and stores it in a SQLite database.

```python
import argparse
import os
import sqlite3

import pandas as pd
import yfinance as yf


def fetch_stock_data(symbol, start_date, end_date, db_path):
    try:
        # Make sure the data/ directory exists before SQLite tries to create the file.
        os.makedirs(os.path.dirname(db_path), exist_ok=True)
        conn = sqlite3.connect(db_path)

        data = yf.download(symbol, start=start_date, end=end_date)
        # Newer yfinance versions may return MultiIndex columns; flatten to plain names like "Close".
        if isinstance(data.columns, pd.MultiIndex):
            data.columns = data.columns.get_level_values(0)

        table_name = f"{symbol}_prices"
        data.to_sql(name=table_name, con=conn, if_exists="replace")
        conn.close()
        print(f"Data for {symbol} stored in the database.")
    except sqlite3.Error as e:
        print(f"Error while storing data for {symbol}: {e}")
    except Exception as e:
        print(f"Error while fetching data for {symbol}: {e}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fetch historical stock data from Yahoo Finance.")
    parser.add_argument("--symbol", type=str, default="^GSPC", help="Stock symbol (default: ^GSPC)")
    parser.add_argument("--start_date", type=str, default="2000-01-01", help="Start date (default: 2000-01-01)")
    parser.add_argument("--end_date", type=str, default="2023-05-31", help="End date (default: 2023-05-31)")
    args = parser.parse_args()

    db_path = os.path.join(os.path.dirname(__file__), "../data/stock_data.db")
    fetch_stock_data(args.symbol, args.start_date, args.end_date, db_path)
```
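Once the script has run, it can help to sanity-check what actually landed in the database before wiring it into the Dash app. Below is a minimal sketch, assuming the default `^GSPC` symbol and the `data/stock_data.db` path used above, run from the project root:

```python
# Quick sanity check of the stored data (illustrative; not part of the MVP scripts).
import sqlite3

import pandas as pd

db_path = "data/stock_data.db"   # assumes you run this from the project root
symbol = "^GSPC"                 # default symbol used by fetch_stock_data.py

conn = sqlite3.connect(db_path)
# Quote the table name: symbols such as ^GSPC are not valid bare SQL identifiers.
df = pd.read_sql_query(f'SELECT Date, Close FROM "{symbol}_prices" ORDER BY Date', conn)
conn.close()

print(df.head())
print(f"{len(df)} rows stored for {symbol}")
```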

### Simplified `stock_prediction_app.py`

This Dash application reads the stored data from the SQLite database and displays it. The prediction functionality is kept simple: a small LSTM is trained on the fly and predicts only the next closing price.

```python
import os
import sqlite3

import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.models import Sequential

app = dash.Dash(__name__)

app.layout = html.Div([
    html.H1("Stock Price Prediction"),
    html.Div([
        html.Label("Select Stock Symbol"),
        dcc.Dropdown(
            id="stock-dropdown",
            options=[{"label": "S&P 500", "value": "^GSPC"},
                     {"label": "Dow Jones", "value": "^DJI"},
                     {"label": "Nasdaq", "value": "^IXIC"}],
            value="^GSPC"
        )
    ]),
    html.Div([
        dcc.Graph(id="stock-graph")
    ])
])


@app.callback(Output("stock-graph", "figure"),
              [Input("stock-dropdown", "value")])
def update_graph(stock_symbol):
    try:
        # Load the closing prices stored by fetch_stock_data.py.
        db_path = os.path.join(os.path.dirname(__file__), "../data/stock_data.db")
        conn = sqlite3.connect(db_path)
        # Quote the table name: symbols such as ^GSPC are not valid bare SQL identifiers.
        data = pd.read_sql_query(f'SELECT Date, Close FROM "{stock_symbol}_prices"', conn)
        conn.close()

        # Scale prices to [0, 1] and build sliding windows of the last `lookback` closes.
        scaler = MinMaxScaler(feature_range=(0, 1))
        scaled_data = scaler.fit_transform(data["Close"].values.reshape(-1, 1))

        lookback = 60
        X, y = [], []
        for i in range(lookback, len(scaled_data)):
            X.append(scaled_data[i - lookback:i, 0])
            y.append(scaled_data[i, 0])
        X, y = np.array(X), np.array(y)
        X = np.reshape(X, (X.shape[0], X.shape[1], 1))

        # Train a small LSTM for a single epoch (kept deliberately quick for the MVP).
        model = Sequential()
        model.add(LSTM(units=50, return_sequences=True, input_shape=(X.shape[1], 1)))
        model.add(LSTM(units=50))
        model.add(Dense(1))
        model.compile(loss="mean_squared_error", optimizer="adam")
        model.fit(X, y, epochs=1, batch_size=32)

        # Predict the next closing price from the most recent window.
        last_data = scaled_data[-lookback:]
        X_test = np.array([last_data])
        X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))
        predicted_prices = model.predict(X_test)
        predicted_prices = scaler.inverse_transform(predicted_prices)

        fig = go.Figure()
        fig.add_trace(go.Scatter(x=data["Date"], y=data["Close"], name="Actual Price"))
        fig.add_trace(go.Scatter(x=data["Date"][-len(predicted_prices):], y=predicted_prices.flatten(), name="Predicted Price"))
        fig.update_layout(title=f"{stock_symbol} Stock Price Prediction", xaxis_title="Date", yaxis_title="Price")

        return fig
    except sqlite3.Error as e:
        print(f"Error while fetching data for {stock_symbol}: {e}")
        return go.Figure()
    except Exception as e:
        print(f"Error while updating graph for {stock_symbol}: {e}")
        return go.Figure()


if __name__ == "__main__":
    app.run(debug=True)
```
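The callback above deliberately keeps prediction simple: the LSTM is retrained on every dropdown change and forecasts only the next closing price. If you later want a short multi-day forecast, one possible extension is to feed each prediction back into the input window and predict iteratively. The helper below is an illustrative sketch (`forecast_next_days` and its `steps` parameter are not part of the app above); it reuses the `model`, `scaler`, `scaled_data`, and `lookback` already defined inside `update_graph`:

```python
import numpy as np

def forecast_next_days(model, scaler, scaled_data, lookback=60, steps=5):
    """Illustrative sketch: iteratively predict `steps` future closes by
    feeding each prediction back into the lookback window."""
    window = scaled_data[-lookback:, 0].tolist()
    predictions = []
    for _ in range(steps):
        x = np.array(window[-lookback:]).reshape(1, lookback, 1)
        next_scaled = float(model.predict(x, verbose=0)[0, 0])
        window.append(next_scaled)
        predictions.append(next_scaled)
    # Undo the MinMax scaling to get prices back in the original units.
    return scaler.inverse_transform(np.array(predictions).reshape(-1, 1)).flatten()
```

Plotting the returned array as an extra trace against the next few dates would turn the single predicted point into a short forecast curve.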

### `environment.yml`:

```yaml
name: stock_prediction
channels:
  - defaults
dependencies:
  - python=3.9
  - pip
  - pip:
      - -r requirements.txt
```

### `requirements.txt`:

```
dash
pandas
numpy
yfinance
scikit-learn
tensorflow
keras
plotly
```

### Instructions for Running the MVP

1. **Create the Conda Environment**:

   ```bash
   conda env create -f environment.yml
   ```

2. **Activate the Conda Environment**:

   ```bash
   conda activate stock_prediction
   ```

3. **Fetch Historical Stock Data**:

   ```bash
   python scripts/fetch_stock_data.py --symbol ^GSPC --start_date 2000-01-01 --end_date 2023-05-31
   ```

   To chart the other symbols offered in the app's dropdown (`^DJI`, `^IXIC`), run the script once for each of them as well.

4. **Run the Dash Application**:

   ```bash
   python app/stock_prediction_app.py
   ```

5. **Open the Web Application**:

   Open your web browser and go to `http://localhost:8050` to access the application. Select a stock symbol from the dropdown menu to see the actual and predicted stock prices on the graph.

This simplified version focuses on the essential features and ensures that the core functionalities work properly. Once this MVP is stable, you can gradually add more features and improvements.

---

Here's the complete project structure and code in one document:

Project Structure: