Update projects/conda.md

Absolutely! Let's simplify the project to create a Minimum Viable Product (MVP). The MVP will focus on the core functionality: fetching historical stock data, storing it in a SQLite database, and providing a simple web interface to display the actual and predicted stock prices.
### Simplified Project Structure
```
stock_prediction/
├── data/
│   └── stock_data.db
├── scripts/
│   └── fetch_stock_data.py
├── app/
│   └── stock_prediction_app.py
├── environment.yml
└── requirements.txt
```
### Simplified `fetch_stock_data.py`
This script fetches historical stock data and stores it in a SQLite database.
```python
import argparse
import os
import sqlite3

import yfinance as yf


def fetch_stock_data(symbol, start_date, end_date, db_path):
    """Download daily prices for `symbol` and store them in a SQLite table."""
    try:
        conn = sqlite3.connect(db_path)
        data = yf.download(symbol, start=start_date, end=end_date)
        # One table per symbol, e.g. "^GSPC_prices"; pandas quotes the name when writing.
        table_name = f"{symbol}_prices"
        data.to_sql(name=table_name, con=conn, if_exists="replace")
        conn.close()
        print(f"Data for {symbol} stored in the database.")
    except sqlite3.Error as e:
        print(f"Error while storing data for {symbol}: {e}")
    except Exception as e:
        print(f"Error while fetching data for {symbol}: {e}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fetch historical stock data from Yahoo Finance.")
    parser.add_argument("--symbol", type=str, default="^GSPC", help="Stock symbol (default: ^GSPC)")
    parser.add_argument("--start_date", type=str, default="2000-01-01", help="Start date (default: 2000-01-01)")
    parser.add_argument("--end_date", type=str, default="2023-05-31", help="End date (default: 2023-05-31)")
    args = parser.parse_args()

    db_path = os.path.join(os.path.dirname(__file__), "../data/stock_data.db")
    os.makedirs(os.path.dirname(db_path), exist_ok=True)  # make sure data/ exists before connecting
    fetch_stock_data(args.symbol, args.start_date, args.end_date, db_path)
```
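Before wiring up the app, it can be worth confirming that the download actually landed in the database. Here is a minimal check along those lines (a sketch assuming the default `^GSPC` symbol, the `data/stock_data.db` path used above, and that you run it from the project root):
```python
import sqlite3

# Quick sanity check: list the tables and count the rows stored for ^GSPC.
# Assumes the default symbol and database path from fetch_stock_data.py.
conn = sqlite3.connect("data/stock_data.db")
tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
print("Tables:", tables)
rows = conn.execute('SELECT COUNT(*) FROM "^GSPC_prices"').fetchone()[0]
print("Rows stored:", rows)
conn.close()
```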
### Simplified `stock_prediction_app.py`
This Dash application reads the stored prices from the SQLite database and displays them. The prediction functionality is kept deliberately simple: a small LSTM is retrained on each request and predicts a single next value.
```python
import os
import sqlite3

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from dash import Dash, Input, Output, dcc, html
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.models import Sequential

app = Dash(__name__)

app.layout = html.Div([
    html.H1("Stock Price Prediction"),
    html.Div([
        html.Label("Select Stock Symbol"),
        dcc.Dropdown(
            id="stock-dropdown",
            options=[{"label": "S&P 500", "value": "^GSPC"},
                     {"label": "Dow Jones", "value": "^DJI"},
                     {"label": "Nasdaq", "value": "^IXIC"}],
            value="^GSPC"
        )
    ]),
    html.Div([
        dcc.Graph(id="stock-graph")
    ])
])


@app.callback(Output("stock-graph", "figure"),
              [Input("stock-dropdown", "value")])
def update_graph(stock_symbol):
    try:
        # Load the prices stored by scripts/fetch_stock_data.py.
        db_path = os.path.join(os.path.dirname(__file__), "../data/stock_data.db")
        conn = sqlite3.connect(db_path)
        # Quote the table name: symbols such as "^GSPC" contain characters SQLite
        # will not accept in a bare identifier.
        data = pd.read_sql_query(f'SELECT Date, Close FROM "{stock_symbol}_prices"', conn)
        conn.close()

        # Scale closing prices to [0, 1] for the LSTM.
        scaler = MinMaxScaler(feature_range=(0, 1))
        scaled_data = scaler.fit_transform(data["Close"].values.reshape(-1, 1))

        # Build sliding windows: each sample is the previous 60 closes,
        # the target is the close that follows.
        lookback = 60
        X, y = [], []
        for i in range(lookback, len(scaled_data)):
            X.append(scaled_data[i - lookback:i, 0])
            y.append(scaled_data[i, 0])
        X, y = np.array(X), np.array(y)
        X = np.reshape(X, (X.shape[0], X.shape[1], 1))

        # A deliberately small model, retrained on every dropdown change
        # (acceptable for an MVP, but slow for anything more).
        model = Sequential()
        model.add(LSTM(units=50, return_sequences=True, input_shape=(X.shape[1], 1)))
        model.add(LSTM(units=50))
        model.add(Dense(1))
        model.compile(loss="mean_squared_error", optimizer="adam")
        model.fit(X, y, epochs=1, batch_size=32)

        # Predict a single next value from the most recent window and map it back to price scale.
        last_data = scaled_data[-lookback:]
        X_test = np.reshape(np.array([last_data]), (1, lookback, 1))
        predicted_prices = scaler.inverse_transform(model.predict(X_test))

        fig = go.Figure()
        fig.add_trace(go.Scatter(x=data["Date"], y=data["Close"], name="Actual Price"))
        fig.add_trace(go.Scatter(x=data["Date"].iloc[-len(predicted_prices):],
                                 y=predicted_prices.flatten(), name="Predicted Price"))
        fig.update_layout(title=f"{stock_symbol} Stock Price Prediction",
                          xaxis_title="Date", yaxis_title="Price")
        return fig
    except sqlite3.Error as e:
        print(f"Error while fetching data for {stock_symbol}: {e}")
        return go.Figure()
    except Exception as e:
        print(f"Error while updating graph for {stock_symbol}: {e}")
        return go.Figure()


if __name__ == "__main__":
    app.run(debug=True)
```
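The least obvious step in the callback above is the sliding-window loop that turns a single column of closing prices into LSTM training samples. Below is a tiny standalone illustration of the same idea, using a toy array and a lookback of 3 instead of 60:
```python
import numpy as np

# Toy version of the windowing loop from update_graph, with lookback = 3.
prices = np.arange(10, 20).reshape(-1, 1).astype(float)  # stands in for scaled_data

lookback = 3
X, y = [], []
for i in range(lookback, len(prices)):
    X.append(prices[i - lookback:i, 0])  # the previous `lookback` values...
    y.append(prices[i, 0])               # ...predict the value that follows
X, y = np.array(X), np.array(y)
X = np.reshape(X, (X.shape[0], X.shape[1], 1))  # (samples, timesteps, features)

print(X.shape, y.shape)            # (7, 3, 1) (7,)
print(X[0].flatten(), "->", y[0])  # [10. 11. 12.] -> 13.0
```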
### `environment.yml`
```yaml
name: stock_prediction
channels:
  - defaults
dependencies:
  - python=3.9
  - pip
  - pip:
    - -r requirements.txt
```
### `requirements.txt`
```
dash
pandas
numpy
yfinance
scikit-learn
tensorflow
keras
plotly
```
### Instructions for Running the MVP
1. **Create the Conda Environment**:
```bash
conda env create -f environment.yml
```
2. **Activate the Conda Environment**:
```bash
conda activate stock_prediction
```
3. **Fetch Historical Stock Data**:
```bash
python scripts/fetch_stock_data.py --symbol ^GSPC --start_date 2000-01-01 --end_date 2023-05-31
```
Run the same command with `--symbol ^DJI` and `--symbol ^IXIC` if you also want data for the other dropdown options in the app.
4. **Run the Dash Application**:
```bash
python app/stock_prediction_app.py
```
5. **Open the Web Application**:
Open your web browser and go to `http://localhost:8050` to access the application. Select a stock symbol from the dropdown menu to see the actual and predicted stock prices on the graph.
This simplified version focuses on the essential features and ensures that the core functionalities work properly. Once this MVP is stable, you can gradually add more features and improvements.
---
Here's the complete project structure and code in one document:
Project Structure: