Coverage for python3/src/restapi.py: 67%

135 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-01-10 02:24 +0000

1#!../bin/python 

2import sqlite3 

3import logging 

4from logging.handlers import RotatingFileHandler 

5from fastapi import FastAPI, Form, HTTPException, Response, Cookie, Depends 

6from fastapi.security import OAuth2PasswordBearer 

7import jwt 

8from datetime import datetime, timedelta, timezone 

9import os 

10 

# Logging configuration: INFO-level records go both to a size-rotated file
# and to the console.
#
# The log file lives under "<PYTHONPATH>/logs/trading.log". PYTHONPATH may be
# unset (None) or lack a trailing separator, so the path is assembled with
# os.path.join instead of string interpolation (which produced paths such as
# "Nonelogs/trading.log"). A PYTHONPATH that already ends in a separator
# yields exactly the same path as before.
# NOTE(review): the "logs" directory is not created here; it presumably has
# to exist before this module is imported - confirm against deployment.
python_path = os.environ.get("PYTHONPATH")
log_file_path = os.path.join(python_path or "", "logs", "trading.log")
file_format = "%(asctime)s - %(levelname)s - %(message)s"
stdout_format = "%(levelname)s: %(message)s"

# Create handlers
# File handler: rotates at maxBytes and keeps backupCount old files.
file_handler = RotatingFileHandler(
    log_file_path,
    maxBytes=10000,
    backupCount=3)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter(file_format))

# Console handler: same level, shorter format without a timestamp.
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(logging.Formatter(stdout_format))

# Module-level logger used by everything defined in this file.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)

31 

# The FastAPI application that the route decorators below register on.
app = FastAPI()


# Secret key to sign JWT tokens.
# SECURITY: a signing key committed to source control lets anyone who can
# read the repository forge valid tokens. The JWT_SECRET_KEY environment
# variable now takes precedence; the literal below is kept only as a
# backward-compatible fallback and should not be relied on in production.
SECRET_KEY = os.environ.get("JWT_SECRET_KEY") or "your-secret-key"
ALGORITHM = "HS256"


# Token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES = 20
REFRESH_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7  # 7 days
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

44 

45 

46class Sqlite_db(): # pragma: no cove 

47 def database_create_user_table(self): 

48 with sqlite3.connect('trading_database.db') as connection: 

49 cursor = connection.cursor() 

50 cursor.execute('''CREATE TABLE IF NOT EXISTS users ( 

51 id INTEGER PRIMARY KEY, 

52 username TEXT UNIQUE, 

53 password TEXT 

54 )''') 

55 

56 def database_create_user(self, username: str, password: str): 

57 with sqlite3.connect('trading_database.db') as connection: 

58 cursor = connection.cursor() 

59 try: 

60 cursor.execute( 

61 'INSERT INTO users (username, password) VALUES (?, ?)', 

62 (username, password)) 

63 except sqlite3.IntegrityError: 

64 print("Warn: The user alredy exists in the database") 

65 

66 def database_select_user_by_username(self, username: str) -> tuple | None: 

67 with sqlite3.connect('trading_database.db') as connection: 

68 cursor = connection.cursor() 

69 cursor.execute( 

70 'SELECT * FROM users WHERE username = ?', 

71 (username,)) 

72 row = cursor.fetchone() 

73 return row 

74 

75 def exists_robinhood_profile(self, user_id: int) -> bool: 

76 with sqlite3.connect('trading_database.db') as connection: 

77 cursor = connection.cursor() 

78 sql_query = '''SELECT EXISTS(SELECT 1 FROM robinhood_profiles 

79 WHERE user_id = ? LIMIT 1)''' 

80 cursor.execute(sql_query, (user_id,)) 

81 exists = cursor.fetchone()[0] 

82 

83 if exists: 

84 logger.info('A robinhood profile with a user_id of ' 

85 f'{user_id} is in our database already.') 

86 return True 

87 logger.info('System did not find a robinhood profile in our database ' 

88 f'with a user_id of {user_id}.') 

89 return False 

90 

91 def update_robinhood_profile( 

92 self, 

93 robinhood_username: str, 

94 robinhood_password: str, 

95 user_id: int) -> bool: 

96 logger.info('Request to update the username of a the robinhood ' 

97 f'profile for user_id {user_id} opened.') 

98 

99 return False 

100 

101 def create_robinhood_profile( 

102 self, 

103 robinhood_username: str, 

104 robinhood_password: str, 

105 user_id: int) -> bool: 

106 logger.info('Request to create a robinhood profile for the user_id ' 

107 f'{user_id} opened.') 

108 with sqlite3.connect('trading_database.db') as connection: 

109 sql_query = '''INSERT INTO robinhood_profiles 

110(username, password, user_id) VALUES (?, ?, ?);''' 

111 cursor = connection.cursor() 

112 cursor.execute( 

113 sql_query, 

114 ( 

115 robinhood_username, 

116 robinhood_password, 

117 user_id 

118 ) 

119 ) 

120 if cursor.rowcount == 1: 

121 logger.info('Request to create a robinhood profile for the ' 

122 f'user_id {user_id} completed and closed.') 

123 return True 

124 return False 

125 

126 

def validate_user(user_record: tuple, username: str, password: str) -> bool:
    """Check submitted credentials against a ``users`` table row.

    The comparison uses ``hmac.compare_digest`` so that the time taken does
    not reveal how many leading characters of the stored values matched,
    which the previous ``!=`` comparison could.

    Args:
        user_record: Row whose index 1 is the stored username and index 2
            the stored password.
        username: Username supplied by the caller.
        password: Password supplied by the caller.

    Returns:
        ``True`` only when both the username and the password match the
        row. A row whose stored values are not strings (for example a NULL
        password) never validates.
    """
    import hmac  # local import keeps this block self-contained

    stored_username, stored_password = user_record[1], user_record[2]
    values = (stored_username, stored_password, username, password)
    if not all(isinstance(value, str) for value in values):
        return False

    def same(left: str, right: str) -> bool:
        # compare_digest rejects non-ASCII str, so compare the UTF-8 bytes.
        return hmac.compare_digest(
            left.encode("utf-8", "surrogatepass"),
            right.encode("utf-8", "surrogatepass"))

    # Evaluate both comparisons before combining them so a wrong username
    # does not skip the password comparison.
    username_matches = same(stored_username, username)
    password_matches = same(stored_password, password)
    return username_matches and password_matches

133 

134 

# Function to create access token
def create_access_token(data: dict, expires_delta: timedelta):
    """Build a signed JWT from ``data`` with an added ``exp`` claim.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object
            is not modified.
        expires_delta: How far in the future the token expires.

    Returns:
        The token produced by ``jwt.encode`` using the module's
        ``SECRET_KEY`` and ``ALGORITHM``.
    """
    to_encode = data.copy()
    # Timezone-aware "now" replaces the deprecated, naive datetime.utcnow();
    # it denotes the same instant.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

142 

143 

def auth_user(session_token: str = Cookie(None, alias='access_token')):
    """Validate the JWT carried in the ``access_token`` cookie.

    The token is decoded with the module's ``SECRET_KEY`` and ``ALGORITHM``
    and must carry ``"tradingapi"`` as both audience and issuer.

    Args:
        session_token: Value of the ``access_token`` cookie, or ``None``
            when the request has no such cookie.

    Returns:
        The decoded token payload.

    Raises:
        HTTPException: Status 401 with a ``WWW-Authenticate: Bearer``
            header when the cookie is missing or the token is expired, has
            the wrong audience or issuer, or is otherwise invalid.
    """

    def unauthorized(detail: str) -> HTTPException:
        # Every failure shares the status code and header; only the detail
        # text differs.
        return HTTPException(
            status_code=401,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"})

    if not session_token:
        logger.info("auth_user: No session token was provided by requester.")
        raise unauthorized("Could not validate credentials.")

    # The more specific PyJWT errors are listed before InvalidTokenError so
    # that each gets its own detail text.
    # The warnings below are single strings built by implicit concatenation.
    # They used to be passed as tuples, which logged the tuple's repr.
    try:
        payload = jwt.decode(
            session_token,
            SECRET_KEY,
            audience="tradingapi",
            issuer="tradingapi",
            algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError as err:
        logger.warning(
            "auth_user: The session token provided"
            " by the requester has expired.")
        raise unauthorized("Token has expired.") from err
    except jwt.InvalidAudienceError as err:
        logger.warning(
            "auth_user: The session token was"
            " provided to a invalid audience.")
        raise unauthorized(
            "Token was provided to a invalid audience.") from err
    except jwt.InvalidIssuerError as err:
        logger.warning(
            "auth_user: The session token"
            " is from a invalid issuer.")
        raise unauthorized("Token is from a invalid issuer.") from err
    except jwt.InvalidTokenError as err:
        logger.warning(
            "auth_user: The session token provided"
            " by the requester is not valid.")
        raise unauthorized("Could not validate credentials.") from err

    logger.info("auth_user: The token provided was a valid token")
    return payload

194 

195 

@app.get("/")
def home():
    """Serve the root path with a fixed greeting string."""
    greeting = "Hello World"
    return greeting

199 

200 

@app.post("/login")
async def login(
        response: Response,
        username: str = Form(),
        password: str = Form()):
    """Check form credentials and set the ``access_token`` cookie.

    Args:
        response: Response object on which the cookie is set.
        username: ``username`` form field.
        password: ``password`` form field.

    Returns:
        The string ``"User has been logged in"`` on success.

    Raises:
        HTTPException: Status 401 when no user row exists for ``username``
            or ``validate_user`` rejects the credentials.
    """
    db = Sqlite_db()
    user_record = db.database_select_user_by_username(username)
    if not user_record or not validate_user(user_record, username, password):
        raise HTTPException(status_code=401, detail="Invalid credentials")

    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Timezone-aware "now" replaces the deprecated datetime.utcnow() followed
    # by replace(tzinfo=timezone.utc); the resulting instant is the same.
    cookie_expires_utc = datetime.now(timezone.utc) + access_token_expires
    access_token = create_access_token(
        data={
            "sub": username,
            "iss": "tradingapi",
            "aud": "tradingapi"},
        expires_delta=access_token_expires)
    response.set_cookie(
        key="access_token",
        value=access_token,
        httponly=True,
        secure=True,
        expires=cookie_expires_utc)
    return "User has been logged in"

226 

227 

@app.get("/protected-route")
async def protected_route(current_user: dict = Depends(auth_user)):
    """Return a fixed message together with the value yielded by ``auth_user``."""
    body = {"message": "This is a protected route"}
    body["user"] = current_user
    return body

231 

232 

@app.post("/api/v1/robinhood_profile")
def update_robinhood_profile(
        current_user: dict = Depends(auth_user),
        robinhood_user: str = Form(),
        robinhood_password: str = Form()):
    """Create or update the robinhood profile of the authenticated user.

    Args:
        current_user: Token payload from ``auth_user``; its ``"sub"`` entry
            is used as the username to look up.
        robinhood_user: ``robinhood_user`` form field.
        robinhood_password: ``robinhood_password`` form field.

    Returns:
        The string ``"Robinhood user updated in profile"``.

    Raises:
        HTTPException: Status 500 when the token's subject has no row in
            the ``users`` table.
    """
    db = Sqlite_db()
    user_record = db.database_select_user_by_username(current_user['sub'])
    if not user_record:
        # Must be raised: returning the exception object made FastAPI send
        # it back as an ordinary 200 response body.
        raise HTTPException(
            status_code=500,
            detail="user was not found in db.")

    database_user_id = user_record[0]
    if db.exists_robinhood_profile(database_user_id):
        succeeded = db.update_robinhood_profile(
            robinhood_user,
            robinhood_password,
            database_user_id)
    else:
        succeeded = db.create_robinhood_profile(
            robinhood_user,
            robinhood_password,
            database_user_id)
    if not succeeded:
        # The response text is unchanged for existing clients, but a
        # reported failure is no longer dropped silently.
        logger.warning(
            'Robinhood profile write for user_id %s did not report success.',
            database_user_id)
    return "Robinhood user updated in profile"

256 

257 

if __name__ == "__main__":  # pragma: no cover
    # Script entry point: serve the app with uvicorn on the loopback address.
    import uvicorn

    server_options = {"host": "127.0.0.1", "port": 8000}
    uvicorn.run(app, **server_options)