Coverage for src/gitlabracadabra/containers/authenticated_session.py: 85%

111 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-14 23:10 +0200

1# 

2# Copyright (C) 2019-2025 Mathieu Parent <math.parent@gmail.com> 

3# 

4# This program is free software: you can redistribute it and/or modify 

5# it under the terms of the GNU Lesser General Public License as published by 

6# the Free Software Foundation, either version 3 of the License, or 

7# (at your option) any later version. 

8# 

9# This program is distributed in the hope that it will be useful, 

10# but WITHOUT ANY WARRANTY; without even the implied warranty of 

11# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

12# GNU Lesser General Public License for more details. 

13# 

14# You should have received a copy of the GNU Lesser General Public License 

15# along with this program. If not, see <http://www.gnu.org/licenses/>. 

16 

17from __future__ import annotations 

18 

19from time import time 

20from typing import TYPE_CHECKING 

21from urllib.parse import urlparse 

22from urllib.request import parse_http_list, parse_keqv_list 

23 

24from requests import PreparedRequest, Response, codes 

25from requests.structures import CaseInsensitiveDict 

26 

27from gitlabracadabra import __version__ as gitlabracadabra_version 

28from gitlabracadabra.auth_info import AuthInfo 

29from gitlabracadabra.session import Session 

30 

31if TYPE_CHECKING: 31 ↛ 32line 31 didn't jump to line 32 because the condition on line 31 was never true

32 from collections.abc import Iterable, MutableMapping 

33 from typing import Any 

34 

35 from requests.auth import AuthBase 

36 

37 from gitlabracadabra.containers.scope import Scope 

38 

39 Params = ( 

40 MutableMapping[ 

41 str, 

42 str | list[str], 

43 ] 

44 | None 

45 ) 

46 Data = Iterable[bytes] 

47 _SimpleParams = dict[str, str | list[str]] 

48 _TokenKey = tuple[str, str, int | None, str | None] 

49 

50 

class Token:
    """Token string paired with a locally tracked lifetime."""

    def __init__(
        self,
        token: str,
        expires_in: int,
    ) -> None:
        """Store the token and start counting its lifetime.

        Args:
            token: The raw token value.
            expires_in: Lifetime in seconds; anything below 60 is raised to 60.
        """
        lifetime_floor_seconds = 60

        # The issue time is taken from the local clock, not from any
        # issued_at value supplied alongside the token.
        self._issued_at = time()
        self._token = token
        self._expires_in = max(expires_in, lifetime_floor_seconds)

    @property
    def token(self) -> str:
        """Raw token value.

        Returns:
            The token string given at construction.
        """
        return self._token

    @property
    def expiration_time(self) -> float:
        """Moment at which the token stops being valid.

        Returns:
            Local issue time plus the (possibly raised) lifetime, in seconds.
        """
        return self._issued_at + self._expires_in

    def is_expired(self) -> bool:
        """Tell whether the token lifetime has elapsed.

        Returns:
            True once the current time has reached the expiration time.
        """
        now = time()
        return now >= self.expiration_time

100 

101 

class AuthenticatedSession(Session):
    """Session with auth per-host.

    Bearer tokens are cached per (scheme, host, port, scopes) and added to
    requests sent through ``authenticated_request`` and to rebuilt
    (redirected) requests.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Instantiate a session.

        Args:
            args: Positional arguments, forwarded to the parent class.
            kwargs: Named arguments, forwarded to the parent class.
        """
        super().__init__(*args, **kwargs)
        self.headers = CaseInsensitiveDict(
            {
                "User-Agent": f"gitlabracadabra/{gitlabracadabra_version}",
                "Docker-Distribution-Api-Version": "registry/2.0",
            }
        )

        # Added attributes
        self.scheme = "https"
        self.connection_hostname = ""
        self.auth_info = AuthInfo()
        # Tokens, by set of scheme, host, port and scopes (as query string or None for all scope)
        self._tokens: dict[_TokenKey, Token] = {}
        self._current_scopes: set[Scope] | None = None

    def authenticated_request(
        self,
        method: str,
        url: str,
        params: Params | None = None,
        data: Data | None = None,
        headers: dict[str, str] | None = None,
        auth: AuthBase | None = None,
        stream: bool | None = None,
    ) -> Response:
        """Send an HTTP request, adding a cached bearer token when one exists.

        Args:
            method: HTTP method.
            url: Either a path (starting with "/") or a full url.
            params: query string params.
            data: Request body stream.
            headers: Request headers. Never modified in place.
            auth: HTTPBasicAuth.
            stream: Stream the response.

        Returns:
            A Response.
        """
        if url.startswith("/"):
            url = f"{self.scheme}://{self.connection_hostname}{url}"
        token = self._get_token(url, self._current_scopes)
        if token is not None:
            # Build a new dict instead of mutating the argument: callers may
            # hand over a shared mapping (e.g. self.auth_info.headers) that
            # must not retain the bearer token afterwards.
            headers = {**(headers or {}), "Authorization": f"Bearer {token.token}"}
        return self.request(
            method,
            url,
            params=params,
            data=data,
            headers=headers,
            auth=auth,
            stream=stream,
        )

    def rebuild_auth(self, prepared_request: PreparedRequest, response: Response) -> None:
        """Override Session method to inject bearer tokens.

        Args:
            prepared_request: Prepared request.
            response: Response.
        """
        super().rebuild_auth(prepared_request, response)  # type: ignore
        token = self._get_token(prepared_request.url or "", self._current_scopes)
        if token is not None:
            prepared_request.headers["Authorization"] = f"Bearer {token.token}"

    def connect(self, scopes: set[Scope] | None) -> None:
        """Connect.

        Records the scopes as current, then makes sure a token is cached for
        the ``/v2/`` endpoint, requesting one when the server answers with a
        Bearer challenge.

        Args:
            scopes: An optional set of scopes.

        Raises:
            ValueError: If a Bearer challenge is received while scopes is None.
        """
        self._current_scopes = scopes
        url = f"{self.scheme}://{self.connection_hostname}/v2/"
        # A token for these scopes, or an all-scopes token, is enough.
        if self._get_token(url, scopes) or self._get_token(url, None):
            return
        response = self.authenticated_request("get", url)
        if response.history:
            # The request was redirected: keep talking to the final host.
            self.connection_hostname = urlparse(response.url).hostname or self.connection_hostname
        if response.status_code == codes["ok"]:
            # No authentication required; cache a placeholder for all scopes.
            one_hour = 3600
            self._set_token(response, None, Token("no_auth", one_hour))
            return
        # .get() so that a 401 without a challenge header falls through to
        # raise_for_status() instead of failing with a KeyError.
        challenge = response.headers.get("Www-Authenticate", "")
        if response.status_code == codes["unauthorized"] and challenge.startswith("Bearer "):
            self._get_bearer_token(response)
            return
        response.raise_for_status()

    def _get_bearer_token(self, response: Response) -> None:
        """Answer a Bearer challenge and cache the resulting token.

        Args:
            response: The response carrying the Www-Authenticate challenge.

        Raises:
            ValueError: If no scopes are currently set.
            KeyError: If the challenge has no "realm" parameter.
        """
        if self._current_scopes is None:
            msg = "Cannot request a bearer token without scopes"
            raise ValueError(msg)
        challenge_parameters = self._get_challenge_parameters(response)
        get_params: _SimpleParams = {}
        if "service" in challenge_parameters:
            get_params["service"] = challenge_parameters["service"]
        # Sorted so the query is deterministic for a given set of scopes.
        get_params["scope"] = [
            f"repository:{scope.remote_name}:{scope.actions}" for scope in sorted(self._current_scopes)
        ]
        challenge_response = self.authenticated_request(
            "get",
            challenge_parameters["realm"],
            params=get_params,
            headers=self.auth_info.headers,
            auth=self.auth_info.auth,
        )
        challenge_response.raise_for_status()
        payload = challenge_response.json()
        self._set_token(
            response,
            self._current_scopes,
            Token(
                str(payload.get("token", payload.get("access_token", ""))),
                # "or 0" also covers an explicit null; Token applies its own
                # minimum lifetime.
                int(payload.get("expires_in") or 0),
            ),
        )

    def _get_challenge_parameters(self, response: Response) -> dict[str, str]:
        """Parse the key=value pairs following "Bearer " in Www-Authenticate."""
        _, _, challenge = response.headers["Www-Authenticate"].partition("Bearer ")
        return parse_keqv_list(parse_http_list(challenge))

    def _token_key(self, url: str, scopes: set[Scope] | None) -> _TokenKey:
        """Build the cache key (scheme, host, port, scopes hash) for a url."""
        parsed = urlparse(url)
        return (
            parsed.scheme,
            parsed.hostname or "",
            parsed.port,
            self._scopes_hash(scopes),
        )

    def _get_token(self, url: str, scopes: set[Scope] | None) -> Token | None:
        """Return the cached, unexpired token for url and scopes, if any.

        An expired token is evicted from the cache and None is returned.
        """
        key = self._token_key(url, scopes)
        token = self._tokens.get(key)
        if token is not None and token.is_expired():
            del self._tokens[key]
            return None
        return token

    def _set_token(self, response: Response, scopes: set[Scope] | None, token: Token) -> None:
        """Cache token under the key derived from response.url and scopes."""
        self._tokens[self._token_key(response.url, scopes)] = token

    def _scopes_hash(self, scopes: set[Scope] | None) -> str | None:
        """Return a stable string for a set of scopes, or None for all scopes."""
        if scopes is None:
            return None
        return ",".join(map(str, sorted(scopes)))