Coverage for python3/src/restapi.py: 67%
135 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-01-10 02:24 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2024-01-10 02:24 +0000
1#!../bin/python
2import sqlite3
3import logging
4from logging.handlers import RotatingFileHandler
5from fastapi import FastAPI, Form, HTTPException, Response, Cookie, Depends
6from fastapi.security import OAuth2PasswordBearer
7import jwt
8from datetime import datetime, timedelta, timezone
9import os
# --- Logging -----------------------------------------------------------------
# Log to a size-rotated file and to the console. The log directory lives under
# the first $PYTHONPATH entry (the convention this module already relied on).
# When PYTHONPATH is unset, fall back to the current working directory instead
# of building a bogus "Nonelogs/..." path; os.path.join also supplies the
# separator that plain string concatenation silently dropped.
python_path = os.environ.get("PYTHONPATH")
_log_root = (python_path or "").split(os.pathsep)[0]
log_file_path = os.path.join(_log_root, "logs", "trading.log")
# RotatingFileHandler opens its file on construction, so a missing directory
# would make importing this module fail; create it up front.
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
file_format = "%(asctime)s - %(levelname)s - %(message)s"
stdout_format = "%(levelname)s: %(message)s"

# Create handlers: a rotating file handler (3 backups of ~10 kB each) ...
file_handler = RotatingFileHandler(
    log_file_path,
    maxBytes=10000,
    backupCount=3)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter(file_format))

# ... and a console handler with a shorter format.
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(logging.Formatter(stdout_format))

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)
app = FastAPI()

# --- JWT configuration ---------------------------------------------------------
# Secret key used to sign and verify JWT tokens. Override it with the
# TRADING_API_SECRET_KEY environment variable. The built-in placeholder is kept
# as the fallback for backward compatibility, but because it is committed to
# source control, anyone who knows it can forge valid tokens.
_PLACEHOLDER_SECRET_KEY = "your-secret-key"
SECRET_KEY = (os.environ.get("TRADING_API_SECRET_KEY")
              or _PLACEHOLDER_SECRET_KEY)
if SECRET_KEY == _PLACEHOLDER_SECRET_KEY:
    logger.warning(
        "JWT SECRET_KEY is the built-in placeholder; set "
        "TRADING_API_SECRET_KEY to a long random value.")
ALGORITHM = "HS256"

# Token expiration times in minutes.
ACCESS_TOKEN_EXPIRE_MINUTES = 20
REFRESH_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7  # 7 days
# NOTE(review): neither this scheme nor a "token" route is used in the visible
# code; requests are authenticated from the "access_token" cookie instead.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
46class Sqlite_db(): # pragma: no cove
47 def database_create_user_table(self):
48 with sqlite3.connect('trading_database.db') as connection:
49 cursor = connection.cursor()
50 cursor.execute('''CREATE TABLE IF NOT EXISTS users (
51 id INTEGER PRIMARY KEY,
52 username TEXT UNIQUE,
53 password TEXT
54 )''')
56 def database_create_user(self, username: str, password: str):
57 with sqlite3.connect('trading_database.db') as connection:
58 cursor = connection.cursor()
59 try:
60 cursor.execute(
61 'INSERT INTO users (username, password) VALUES (?, ?)',
62 (username, password))
63 except sqlite3.IntegrityError:
64 print("Warn: The user alredy exists in the database")
66 def database_select_user_by_username(self, username: str) -> tuple | None:
67 with sqlite3.connect('trading_database.db') as connection:
68 cursor = connection.cursor()
69 cursor.execute(
70 'SELECT * FROM users WHERE username = ?',
71 (username,))
72 row = cursor.fetchone()
73 return row
75 def exists_robinhood_profile(self, user_id: int) -> bool:
76 with sqlite3.connect('trading_database.db') as connection:
77 cursor = connection.cursor()
78 sql_query = '''SELECT EXISTS(SELECT 1 FROM robinhood_profiles
79 WHERE user_id = ? LIMIT 1)'''
80 cursor.execute(sql_query, (user_id,))
81 exists = cursor.fetchone()[0]
83 if exists:
84 logger.info('A robinhood profile with a user_id of '
85 f'{user_id} is in our database already.')
86 return True
87 logger.info('System did not find a robinhood profile in our database '
88 f'with a user_id of {user_id}.')
89 return False
91 def update_robinhood_profile(
92 self,
93 robinhood_username: str,
94 robinhood_password: str,
95 user_id: int) -> bool:
96 logger.info('Request to update the username of a the robinhood '
97 f'profile for user_id {user_id} opened.')
99 return False
101 def create_robinhood_profile(
102 self,
103 robinhood_username: str,
104 robinhood_password: str,
105 user_id: int) -> bool:
106 logger.info('Request to create a robinhood profile for the user_id '
107 f'{user_id} opened.')
108 with sqlite3.connect('trading_database.db') as connection:
109 sql_query = '''INSERT INTO robinhood_profiles
110(username, password, user_id) VALUES (?, ?, ?);'''
111 cursor = connection.cursor()
112 cursor.execute(
113 sql_query,
114 (
115 robinhood_username,
116 robinhood_password,
117 user_id
118 )
119 )
120 if cursor.rowcount == 1:
121 logger.info('Request to create a robinhood profile for the '
122 f'user_id {user_id} completed and closed.')
123 return True
124 return False
def validate_user(user_record: tuple, username: str, password: str) -> bool:
    """Check submitted credentials against a ``users`` row.

    :param user_record: a row shaped like the ``users`` table,
        ``(id, username, password)``.
    :param username: username submitted by the client.
    :param password: password submitted by the client.
    :returns: True only if both the username and the password match.

    NOTE(review): passwords are compared in plain text because that is how
    they are stored; switch to a password hash when storage changes.
    """
    import hmac  # local import keeps the module import block unchanged

    stored_username, stored_password = user_record[1], user_record[2]
    if username != stored_username:
        return False
    # A NULL/non-text stored password (or a non-str input) can never match.
    if not (isinstance(password, str) and isinstance(stored_password, str)):
        return False
    # Constant-time comparison, so response timing does not reveal how much
    # of the password prefix was correct.
    return hmac.compare_digest(
        password.encode("utf-8"), stored_password.encode("utf-8"))
def create_access_token(data: dict, expires_delta: timedelta):
    """Return a signed JWT carrying *data* plus an ``exp`` claim.

    :param data: claims to encode; copied, so the caller's dict is not mutated.
    :param expires_delta: token lifetime, measured from now.
    :returns: the token produced by ``jwt.encode`` with SECRET_KEY/ALGORITHM.
    """
    to_encode = data.copy()
    # Timezone-aware UTC: datetime.utcnow() is deprecated (Python 3.12) and
    # returns a naive datetime.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def auth_user(session_token: str = Cookie(None, alias='access_token')):
    """FastAPI dependency: validate the JWT in the ``access_token`` cookie.

    The token is decoded with SECRET_KEY/ALGORITHM, and both its audience and
    its issuer must be ``"tradingapi"``.

    :returns: the decoded token payload.
    :raises HTTPException: 401 with a reason-specific ``detail`` when the
        cookie is missing, expired, issued for another audience or by another
        issuer, or otherwise invalid.
    """
    def unauthorized(detail: str) -> HTTPException:
        # All failures share the status code and the WWW-Authenticate header.
        return HTTPException(
            status_code=401,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"})

    if not session_token:
        logger.info("auth_user: No session token was provided by requester.")
        raise unauthorized("Could not validate credentials.")
    # The specific jwt errors below all subclass InvalidTokenError, so they
    # must be caught before it.
    # Adjacent string literals (no commas) keep each log message a single
    # string; the previous version logged the repr of a tuple.
    try:
        payload = jwt.decode(
            session_token,
            SECRET_KEY,
            audience="tradingapi",
            issuer="tradingapi",
            algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError:
        logger.warning(
            "auth_user: The session token provided"
            " by the requester has expired.")
        raise unauthorized("Token has expired.") from None
    except jwt.InvalidAudienceError:
        logger.warning(
            "auth_user: The session token was"
            " provided to a invalid audience.")
        raise unauthorized(
            "Token was provided to a invalid audience.") from None
    except jwt.InvalidIssuerError:
        logger.warning(
            "auth_user: The session token"
            " is from a invalid issuer.")
        raise unauthorized("Token is from a invalid issuer.") from None
    except jwt.InvalidTokenError:
        logger.warning(
            "auth_user: The session token provided"
            " by the requester is not valid.")
        raise unauthorized("Could not validate credentials.") from None
    else:
        logger.info("auth_user: The token provided was a valid token")
        return payload
@app.get("/")
def home():
    """Root endpoint: reply with a fixed greeting string."""
    greeting = "Hello World"
    return greeting
@app.post("/login")
async def login(
        response: Response,
        username: str = Form(),
        password: str = Form()):
    """Authenticate a user and set the JWT ``access_token`` cookie.

    On success the cookie is HttpOnly and Secure. It expires after
    ACCESS_TOKEN_EXPIRE_MINUTES, the same lifetime given to the token's
    ``exp`` claim.

    :raises HTTPException: 401 if the user is unknown or the password is wrong.
    """
    # NOTE(review): Sqlite_db does blocking I/O inside this async handler,
    # which stalls the event loop while the query runs. Consider a plain
    # ``def`` endpoint (FastAPI runs it in a threadpool).
    db = Sqlite_db()
    user_record = db.database_select_user_by_username(username)
    if not (user_record and validate_user(user_record, username, password)):
        raise HTTPException(status_code=401, detail="Invalid credentials")

    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Aware UTC directly. This replaces the deprecated datetime.utcnow()
    # followed by .replace(tzinfo=timezone.utc).
    cookie_expires_utc = datetime.now(timezone.utc) + access_token_expires
    access_token = create_access_token(
        data={
            "sub": username,
            "iss": "tradingapi",
            "aud": "tradingapi"},
        expires_delta=access_token_expires)
    response.set_cookie(
        key="access_token",
        value=access_token,
        httponly=True,
        secure=True,
        expires=cookie_expires_utc)
    return "User has been logged in"
@app.get("/protected-route")
async def protected_route(current_user: dict = Depends(auth_user)):
    """Example endpoint reachable only with a valid ``access_token`` cookie.

    ``current_user`` is the decoded token payload supplied by ``auth_user``.
    """
    body = {"message": "This is a protected route"}
    body["user"] = current_user
    return body
@app.post("/api/v1/robinhood_profile")
def update_robinhood_profile(
        current_user: dict = Depends(auth_user),
        robinhood_user: str = Form(),
        robinhood_password: str = Form()):
    """Create or update the Robinhood credentials of the authenticated user.

    The user is looked up by the token's ``sub`` claim. An existing Robinhood
    profile is updated; otherwise a new one is created.

    :raises HTTPException: 500 if the token's user has no ``users`` row.
    """
    db = Sqlite_db()
    user_record = db.database_select_user_by_username(current_user['sub'])
    if not user_record:
        # Must be *raised*: a returned HTTPException is not turned into an
        # error response by FastAPI, so the client saw a success status.
        raise HTTPException(
            status_code=500,
            detail="user was not found in db.")
    database_user_id = user_record[0]
    if db.exists_robinhood_profile(database_user_id):
        db.update_robinhood_profile(
            robinhood_user,
            robinhood_password,
            database_user_id)
    else:
        db.create_robinhood_profile(
            robinhood_user,
            robinhood_password,
            database_user_id)
    # NOTE(review): the bool returned by the Sqlite_db call above is ignored,
    # so this message is sent even if no row was written.
    return "Robinhood user updated in profile"
if __name__ == "__main__":  # pragma: no cover
    import uvicorn

    # Local development server, bound to the loopback interface only.
    dev_server_options = {"host": "127.0.0.1", "port": 8000}
    uvicorn.run(app, **dev_server_options)