Coverage for python3/tests/test_restAPI.py: 100%

62 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-01-10 02:24 +0000

from datetime import datetime, timedelta, timezone

import jwt
from fastapi.testclient import TestClient

from src.restapi import ALGORITHM, SECRET_KEY, Sqlite_db, app


# One TestClient shared by every test in this module.
client = TestClient(app)


def _disabled_test_robinhood_user():
    """Post a signed token to ``/api/v1/set_robinhood_user`` and expect 200.

    NOTE(review): this test was previously disabled by wrapping it in a bare
    string literal, which hid it from linters and syntax checks. It is kept as
    a real function so it keeps compiling, but the name does not start with
    ``test`` so pytest does not collect it. Rename it to
    ``test_robinhood_user`` to re-enable it. Why it was disabled is not
    recorded here — TODO confirm before re-enabling.
    """
    valid_token = jwt.encode(
        {
            "sub": "testuser",
            "iss": "tradingapi",
            "aud": "tradingapi",
        },
        SECRET_KEY,
        algorithm=ALGORITHM,
    )
    response = client.post(
        "/api/v1/set_robinhood_user",
        data={"token": valid_token},
        cookies={"access_token": valid_token},
    )

    assert response.status_code == 200



def test_protected_route_with_valid_token():
    """A correctly signed token with the expected iss/aud opens the route.

    The response body must carry the fixed message plus the token's claims
    under ``user``.
    """
    claims = {
        "sub": "testuser",
        "iss": "tradingapi",
        "aud": "tradingapi",
    }
    token = jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)

    resp = client.get("/protected-route", cookies={"access_token": token})

    assert resp.status_code == 200

    # Spelled out separately from `claims` so the expectation does not depend
    # on whether jwt.encode touches the payload dict it was given.
    expected_body = {
        "message": "This is a protected route",
        "user": {
            "sub": "testuser",
            "iss": "tradingapi",
            "aud": "tradingapi",
        },
    }
    assert resp.json() == expected_body



def test_protected_route_with_expired_token():
    """A token whose ``exp`` lies in the past is rejected with 401."""
    # Otherwise-valid token that expired five minutes ago. Use an aware UTC
    # timestamp: datetime.utcnow() is deprecated since Python 3.12 and returns
    # a naive value that only means UTC by convention.
    expired_at = datetime.now(timezone.utc) - timedelta(minutes=5)
    expired_token = jwt.encode(
        {
            "sub": "testuser",
            "iss": "tradingapi",
            "aud": "tradingapi",
            "exp": expired_at,
        },
        SECRET_KEY,
        algorithm=ALGORITHM,
    )

    response = client.get(
        "/protected-route",
        cookies={"access_token": expired_token},
    )

    # Expiry must be reported as 401 with its own, specific detail message.
    assert response.status_code == 401
    assert response.json() == {"detail": "Token has expired."}



def test_protected_route_with_invalid_iss():
    """A token with the right signature but a foreign issuer is rejected (401)."""
    # Unexpired token whose only deviation from a valid one is the ``iss``
    # claim. Aware UTC time replaces the deprecated, naive datetime.utcnow().
    expires_at = datetime.now(timezone.utc) + timedelta(minutes=5)
    bad_issuer_token = jwt.encode(
        {
            "sub": "testuser",
            "iss": "evil_iss",
            "aud": "tradingapi",
            "exp": expires_at,
        },
        SECRET_KEY,
        algorithm=ALGORITHM,
    )

    response = client.get(
        "/protected-route",
        cookies={"access_token": bad_issuer_token},
    )

    assert response.status_code == 401
    # The detail text must match the server's message exactly, wording included.
    assert response.json() == {"detail": "Token is from a invalid issuer."}



def test_protected_route_with_invalid_aud():
    """A token addressed to a different audience is rejected (401)."""
    # Unexpired token whose only deviation from a valid one is the ``aud``
    # claim. Aware UTC time replaces the deprecated, naive datetime.utcnow().
    expires_at = datetime.now(timezone.utc) + timedelta(minutes=5)
    bad_audience_token = jwt.encode(
        {
            "sub": "testuser",
            "iss": "tradingapi",
            "aud": "evil_aud",
            "exp": expires_at,
        },
        SECRET_KEY,
        algorithm=ALGORITHM,
    )

    response = client.get(
        "/protected-route",
        cookies={"access_token": bad_audience_token},
    )

    assert response.status_code == 401
    # The detail text must match the server's message exactly, wording included.
    assert response.json() == {
        "detail": "Token was provided to a invalid audience.",
    }



def test_protected_route_with_no_token():
    """A request with no ``access_token`` cookie at all is rejected (401)."""
    # Send no cookies: the route must refuse rather than treat the caller as
    # anonymous.
    response = client.get("/protected-route")

    assert response.status_code == 401
    assert response.json() == {"detail": "Could not validate credentials."}



def test_protected_route_with_invalid_token():
    """A cookie that is not a decodable JWT yields the generic 401 error."""
    garbage_cookie = {"access_token": "invalid_token"}

    resp = client.get("/protected-route", cookies=garbage_cookie)

    assert resp.status_code == 401
    assert resp.json() == {"detail": "Could not validate credentials."}



def test_root():
    """The root endpoint answers 200 with the JSON string "Hello World"."""
    resp = client.get("/")

    assert resp.status_code == 200
    assert resp.json() == "Hello World"



def test_login_successful(monkeypatch):
    """Correct username and password yield 200 and set an access_token cookie."""

    def fake_select_user(_self, _username):
        # Canned user row; presumably (id, username, password) — TODO confirm
        # against Sqlite_db's real return shape.
        return (1, "existing_user", "hashed_password")

    # Replace the lookup method on the Sqlite_db class for this test only.
    # monkeypatch undoes the setattr at teardown (even if an assert fails),
    # so no manual cleanup is needed; this test never touches
    # app.dependency_overrides.
    monkeypatch.setattr(
        Sqlite_db, "database_select_user_by_username", fake_select_user
    )

    response = client.post(
        "/login",
        data={"username": "existing_user", "password": "hashed_password"},
    )

    assert response.status_code == 200
    assert "access_token" in response.cookies



def test_login_invalid_username(monkeypatch):
    """An unknown username is rejected with 401 and no access_token cookie."""

    def fake_select_user(_self, _username):
        # NOTE(review): the stub returns the same row whatever username is
        # asked for, so this test relies on /login comparing the submitted
        # username with the returned row. Confirm this matches what the real
        # database method returns for an unknown user.
        return (1, "existing_user", "hashed_password")

    # Scoped to this test; monkeypatch restores the original method at teardown.
    monkeypatch.setattr(
        Sqlite_db, "database_select_user_by_username", fake_select_user
    )

    # Correct password, wrong username.
    response = client.post(
        "/login",
        data={"username": "bad_user", "password": "hashed_password"},
    )

    assert response.status_code == 401
    assert "access_token" not in response.cookies



def test_login_invalid_password(monkeypatch):
    """A wrong password for a known user is rejected with 401 and no cookie."""

    def fake_select_user(_self, _username):
        # Canned user row; presumably (id, username, password) — TODO confirm.
        return (1, "existing_user", "hashed_password")

    # Scoped to this test; monkeypatch restores the original method at teardown.
    monkeypatch.setattr(
        Sqlite_db, "database_select_user_by_username", fake_select_user
    )

    # Correct username, wrong password.
    response = client.post(
        "/login",
        data={"username": "existing_user", "password": "bad_password"},
    )

    assert response.status_code == 401
    assert "access_token" not in response.cookies