diff --git a/src/workos/session.py b/src/workos/session.py index 330a66bd..236a250c 100644 --- a/src/workos/session.py +++ b/src/workos/session.py @@ -336,18 +336,11 @@ def refresh( ) try: - # Use raw dict request because the generated AuthenticateResponse - # doesn't include sealed_session, and the request body needs the - # session parameter which isn't in the generated request models. body: Dict[str, Any] = { "grant_type": "refresh_token", "client_id": self._client.client_id, "client_secret": self._client._api_key, "refresh_token": session["refresh_token"], - "session": { - "seal_session": True, - "cookie_password": effective_cookie_password, - }, } if organization_id is not None: body["organization_id"] = organization_id @@ -358,31 +351,64 @@ def refresh( body=body, ) - self.session_data = str(auth_response["sealed_session"]) - self.cookie_password = effective_cookie_password - - signing_key = self.jwks.get_signing_key_from_jwt( - auth_response["access_token"] - ) - decoded = jwt.decode( - auth_response["access_token"], - signing_key.key, - algorithms=self._JWK_ALGORITHMS, - options={"verify_aud": False}, - leeway=self._client._jwt_leeway, + access_token = auth_response.get("access_token") + refresh_token = auth_response.get("refresh_token") + if not access_token or not refresh_token: + return RefreshWithSessionCookieErrorResponse( + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.REFRESH_DENIED, + ) + + user = auth_response.get("user") or {} + impersonator = auth_response.get("impersonator") + + try: + signing_key = self.jwks.get_signing_key_from_jwt(access_token) + decoded = jwt.decode( + access_token, + signing_key.key, + algorithms=self._JWK_ALGORITHMS, + options={"verify_aud": False}, + leeway=self._client._jwt_leeway, + ) + except ( + jwt.exceptions.InvalidTokenError, + jwt.exceptions.PyJWKClientError, + ): + return RefreshWithSessionCookieErrorResponse( + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT, + ) + + session_id = decoded.get("sid") + if not session_id: + return RefreshWithSessionCookieErrorResponse( + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT, + ) + + new_sealed = seal_session_from_auth_response( + access_token=access_token, + refresh_token=refresh_token, + user=user, + impersonator=impersonator, + cookie_password=effective_cookie_password, ) + self.session_data = new_sealed + self.cookie_password = effective_cookie_password + return RefreshWithSessionCookieSuccessResponse( authenticated=True, - sealed_session=str(auth_response["sealed_session"]), - session_id=decoded["sid"], + sealed_session=new_sealed, + session_id=session_id, organization_id=decoded.get("org_id"), role=decoded.get("role"), roles=decoded.get("roles"), permissions=decoded.get("permissions"), entitlements=decoded.get("entitlements"), - user=auth_response.get("user"), - impersonator=auth_response.get("impersonator"), + user=user, + impersonator=impersonator, feature_flags=decoded.get("feature_flags"), ) except Exception as e: @@ -523,10 +549,6 @@ async def refresh( "client_id": self._client.client_id, "client_secret": self._client._api_key, "refresh_token": session["refresh_token"], - "session": { - "seal_session": True, - "cookie_password": effective_cookie_password, - }, } if organization_id is not None: body["organization_id"] = organization_id @@ -537,31 +559,64 @@ async def refresh( body=body, ) - self.session_data = str(auth_response["sealed_session"]) - self.cookie_password = effective_cookie_password - - signing_key = self.jwks.get_signing_key_from_jwt( - auth_response["access_token"] - ) - decoded = jwt.decode( - auth_response["access_token"], - signing_key.key, - algorithms=self._JWK_ALGORITHMS, - options={"verify_aud": False}, - leeway=self._client._jwt_leeway, + access_token = auth_response.get("access_token") + refresh_token = auth_response.get("refresh_token") + if not access_token or not refresh_token: + return RefreshWithSessionCookieErrorResponse( + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.REFRESH_DENIED, + ) + + user = auth_response.get("user") or {} + impersonator = auth_response.get("impersonator") + + try: + signing_key = self.jwks.get_signing_key_from_jwt(access_token) + decoded = jwt.decode( + access_token, + signing_key.key, + algorithms=self._JWK_ALGORITHMS, + options={"verify_aud": False}, + leeway=self._client._jwt_leeway, + ) + except ( + jwt.exceptions.InvalidTokenError, + jwt.exceptions.PyJWKClientError, + ): + return RefreshWithSessionCookieErrorResponse( + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT, + ) + + session_id = decoded.get("sid") + if not session_id: + return RefreshWithSessionCookieErrorResponse( + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT, + ) + + new_sealed = seal_session_from_auth_response( + access_token=access_token, + refresh_token=refresh_token, + user=user, + impersonator=impersonator, + cookie_password=effective_cookie_password, ) + self.session_data = new_sealed + self.cookie_password = effective_cookie_password + return RefreshWithSessionCookieSuccessResponse( authenticated=True, - sealed_session=str(auth_response["sealed_session"]), - session_id=decoded["sid"], + sealed_session=new_sealed, + session_id=session_id, organization_id=decoded.get("org_id"), role=decoded.get("role"), roles=decoded.get("roles"), permissions=decoded.get("permissions"), entitlements=decoded.get("entitlements"), - user=auth_response.get("user"), - impersonator=auth_response.get("impersonator"), + user=user, + impersonator=impersonator, feature_flags=decoded.get("feature_flags"), ) except Exception as e: diff --git a/tests/test_session.py b/tests/test_session.py index 15e44e78..a1360029 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -26,6 +26,7 @@ AuthenticateWithSessionCookieFailureReason, AuthenticateWithSessionCookieSuccessResponse, RefreshWithSessionCookieErrorResponse, + RefreshWithSessionCookieSuccessResponse, Session, _map_refresh_exception_to_reason, seal_data, @@ -225,6 +226,271 @@ def test_session_refresh_missing_refresh_token(self): result = session.refresh() assert isinstance(result, RefreshWithSessionCookieErrorResponse) + def test_session_refresh_success(self): + new_token = _make_jwt(self.private_key) + sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + session = Session( + client=self.workos, session_data=sealed, cookie_password=COOKIE_PASSWORD + ) + session.jwks = self._mock_jwks() + + api_response = { + "access_token": new_token, + "refresh_token": "rt_new", + "user": {"id": "user_01", "email": "test@example.com"}, + "authentication_method": "Password", + } + session._client.request_raw = MagicMock(return_value=api_response) + + result = session.refresh() + assert isinstance(result, RefreshWithSessionCookieSuccessResponse) + assert result.authenticated + assert result.session_id == "session_01" + assert result.organization_id == "org_01" + assert result.user == {"id": "user_01", "email": "test@example.com"} + + unsealed = unseal_data(result.sealed_session, COOKIE_PASSWORD) + assert unsealed["access_token"] == new_token + assert unsealed["refresh_token"] == "rt_new" + assert unsealed["user"]["id"] == "user_01" + + assert session.session_data == result.sealed_session + + def test_session_refresh_seals_client_side_without_sealed_session_in_response(self): + """Regression: API response never contains sealed_session; the SDK must seal locally.""" + new_token = _make_jwt(self.private_key) + sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + session = Session( + client=self.workos, session_data=sealed, cookie_password=COOKIE_PASSWORD + ) + session.jwks = self._mock_jwks() + + api_response = { + "access_token": new_token, + "refresh_token": "rt_new", + "user": {"id": "user_01"}, + "authentication_method": "Password", + } + session._client.request_raw = MagicMock(return_value=api_response) + + result = session.refresh() + assert isinstance(result, RefreshWithSessionCookieSuccessResponse) + assert result.sealed_session + unsealed = unseal_data(result.sealed_session, COOKIE_PASSWORD) + assert unsealed["access_token"] == new_token + assert unsealed["refresh_token"] == "rt_new" + + def test_session_refresh_with_impersonator(self): + new_token = _make_jwt(self.private_key) + sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + session = Session( + client=self.workos, session_data=sealed, cookie_password=COOKIE_PASSWORD + ) + session.jwks = self._mock_jwks() + + api_response = { + "access_token": new_token, + "refresh_token": "rt_new", + "user": {"id": "user_01"}, + "impersonator": {"email": "admin@example.com"}, + "authentication_method": "Password", + } + session._client.request_raw = MagicMock(return_value=api_response) + + result = session.refresh() + assert isinstance(result, RefreshWithSessionCookieSuccessResponse) + assert result.impersonator == {"email": "admin@example.com"} + unsealed = unseal_data(result.sealed_session, COOKIE_PASSWORD) + assert unsealed["impersonator"]["email"] == "admin@example.com" + + def test_session_refresh_does_not_send_session_param(self): + """The session/seal_session param should not be sent to the API.""" + new_token = _make_jwt(self.private_key) + sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + session = Session( + client=self.workos, session_data=sealed, cookie_password=COOKIE_PASSWORD + ) + session.jwks = self._mock_jwks() + + api_response = { + "access_token": new_token, + "refresh_token": "rt_new", + "user": {"id": "user_01"}, + "authentication_method": "Password", + } + session._client.request_raw = MagicMock(return_value=api_response) + + session.refresh() + + call_kwargs = session._client.request_raw.call_args + sent_body = call_kwargs.kwargs.get("body") or call_kwargs[1].get("body") + assert "session" not in sent_body + + def test_session_refresh_round_trip_authenticate(self): + """The sealed cookie produced by refresh() must be re-authenticatable.""" + new_token = _make_jwt(self.private_key) + sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + session = Session( + client=self.workos, session_data=sealed, cookie_password=COOKIE_PASSWORD + ) + session.jwks = self._mock_jwks() + + api_response = { + "access_token": new_token, + "refresh_token": "rt_new", + "user": {"id": "user_01", "email": "test@example.com"}, + "authentication_method": "Password", + } + session._client.request_raw = MagicMock(return_value=api_response) + + refresh_result = session.refresh() + assert isinstance(refresh_result, RefreshWithSessionCookieSuccessResponse) + + new_session = Session( + client=self.workos, + session_data=refresh_result.sealed_session, + cookie_password=COOKIE_PASSWORD, + ) + new_session.jwks = self._mock_jwks() + + auth_result = new_session.authenticate() + assert isinstance(auth_result, AuthenticateWithSessionCookieSuccessResponse) + assert auth_result.authenticated + assert auth_result.session_id == "session_01" + assert auth_result.user == {"id": "user_01", "email": "test@example.com"} + + def test_session_refresh_maps_auth_error_to_refresh_denied(self): + """AuthenticationError from request_raw maps to REFRESH_DENIED via except.""" + sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + session = Session( + client=self.workos, session_data=sealed, cookie_password=COOKIE_PASSWORD + ) + session._client.request_raw = MagicMock( + side_effect=AuthenticationError("unauthorized") + ) + + result = session.refresh() + assert isinstance(result, RefreshWithSessionCookieErrorResponse) + assert not result.authenticated + assert ( + result.reason == AuthenticateWithSessionCookieFailureReason.REFRESH_DENIED + ) + + def test_session_refresh_missing_access_token_returns_refresh_denied(self): + """A malformed 2xx response missing access_token returns REFRESH_DENIED.""" + sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + session = Session( + client=self.workos, session_data=sealed, cookie_password=COOKIE_PASSWORD + ) + session._client.request_raw = MagicMock(return_value={}) + + result = session.refresh() + assert isinstance(result, RefreshWithSessionCookieErrorResponse) + assert not result.authenticated + assert ( + result.reason == AuthenticateWithSessionCookieFailureReason.REFRESH_DENIED + ) + + def test_session_refresh_missing_sid_returns_invalid_jwt(self): + """A JWT without sid returns INVALID_JWT and preserves prior session.""" + no_sid_token = _make_jwt(self.private_key, claims={"sid": ""}) + original_sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + session = Session( + client=self.workos, + session_data=original_sealed, + cookie_password=COOKIE_PASSWORD, + ) + session.jwks = self._mock_jwks() + + api_response = { + "access_token": no_sid_token, + "refresh_token": "rt_new", + "user": {"id": "user_01"}, + } + session._client.request_raw = MagicMock(return_value=api_response) + + result = session.refresh() + assert isinstance(result, RefreshWithSessionCookieErrorResponse) + assert result.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT + assert session.session_data == original_sealed + + def test_session_refresh_invalid_jwt_returns_invalid_jwt(self): + """A JWT with bad signature maps to INVALID_JWT via the decode guard.""" + _, other_public_key = _generate_rsa_key_pair() + wrong_key_private, _ = _generate_rsa_key_pair() + bad_token = _make_jwt(wrong_key_private) + sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + original_sealed = sealed + session = Session( + client=self.workos, session_data=sealed, cookie_password=COOKIE_PASSWORD + ) + mock_jwks = MagicMock() + mock_signing_key = MagicMock() + mock_signing_key.key = other_public_key + mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key + session.jwks = mock_jwks + + api_response = { + "access_token": bad_token, + "refresh_token": "rt_new", + "user": {"id": "user_01"}, + } + session._client.request_raw = MagicMock(return_value=api_response) + + result = session.refresh() + assert isinstance(result, RefreshWithSessionCookieErrorResponse) + assert result.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT + assert session.session_data == original_sealed + + def test_session_refresh_jwks_error_returns_invalid_jwt(self): + """A JWKS lookup failure maps to INVALID_JWT, not a raw string.""" + import jwt as pyjwt_lib + + new_token = _make_jwt(self.private_key) + original_sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + session = Session( + client=self.workos, + session_data=original_sealed, + cookie_password=COOKIE_PASSWORD, + ) + mock_jwks = MagicMock() + mock_jwks.get_signing_key_from_jwt.side_effect = ( + pyjwt_lib.exceptions.PyJWKClientError("Unable to find a signing key") + ) + session.jwks = mock_jwks + + api_response = { + "access_token": new_token, + "refresh_token": "rt_new", + "user": {"id": "user_01"}, + } + session._client.request_raw = MagicMock(return_value=api_response) + + result = session.refresh() + assert isinstance(result, RefreshWithSessionCookieErrorResponse) + assert result.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT + assert session.session_data == original_sealed + class TestMapRefreshExceptionToReason: @pytest.mark.parametrize( @@ -314,3 +580,168 @@ async def test_async_session_authenticate_success(self, async_workos): result = session.authenticate() assert isinstance(result, AuthenticateWithSessionCookieSuccessResponse) assert result.session_id == "session_01" + + async def test_async_session_refresh_success(self, async_workos): + from unittest.mock import AsyncMock + + private_key, public_key = _generate_rsa_key_pair() + new_token = _make_jwt(private_key) + sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + session = AsyncSession( + client=async_workos, session_data=sealed, cookie_password=COOKIE_PASSWORD + ) + session.jwks = self._mock_jwks(public_key) + + api_response = { + "access_token": new_token, + "refresh_token": "rt_new", + "user": {"id": "user_01", "email": "test@example.com"}, + "authentication_method": "Password", + } + session._client.request_raw = AsyncMock(return_value=api_response) + + result = await session.refresh() + assert isinstance(result, RefreshWithSessionCookieSuccessResponse) + assert result.authenticated + assert result.session_id == "session_01" + + unsealed = unseal_data(result.sealed_session, COOKIE_PASSWORD) + assert unsealed["access_token"] == new_token + assert unsealed["refresh_token"] == "rt_new" + + async def test_async_session_refresh_maps_auth_error_to_refresh_denied( + self, async_workos + ): + from unittest.mock import AsyncMock + + sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + session = AsyncSession( + client=async_workos, session_data=sealed, cookie_password=COOKIE_PASSWORD + ) + session._client.request_raw = AsyncMock( + side_effect=AuthenticationError("unauthorized") + ) + + result = await session.refresh() + assert isinstance(result, RefreshWithSessionCookieErrorResponse) + assert not result.authenticated + assert ( + result.reason == AuthenticateWithSessionCookieFailureReason.REFRESH_DENIED + ) + + async def test_async_session_refresh_missing_access_token_returns_refresh_denied( + self, async_workos + ): + from unittest.mock import AsyncMock + + sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + session = AsyncSession( + client=async_workos, session_data=sealed, cookie_password=COOKIE_PASSWORD + ) + session._client.request_raw = AsyncMock(return_value={}) + + result = await session.refresh() + assert isinstance(result, RefreshWithSessionCookieErrorResponse) + assert not result.authenticated + assert ( + result.reason == AuthenticateWithSessionCookieFailureReason.REFRESH_DENIED + ) + + async def test_async_session_refresh_missing_sid_returns_invalid_jwt( + self, async_workos + ): + from unittest.mock import AsyncMock + + private_key, public_key = _generate_rsa_key_pair() + no_sid_token = _make_jwt(private_key, claims={"sid": ""}) + original_sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + session = AsyncSession( + client=async_workos, + session_data=original_sealed, + cookie_password=COOKIE_PASSWORD, + ) + session.jwks = self._mock_jwks(public_key) + + api_response = { + "access_token": no_sid_token, + "refresh_token": "rt_new", + "user": {"id": "user_01"}, + } + session._client.request_raw = AsyncMock(return_value=api_response) + + result = await session.refresh() + assert isinstance(result, RefreshWithSessionCookieErrorResponse) + assert result.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT + assert session.session_data == original_sealed + + async def test_async_session_refresh_invalid_jwt_returns_invalid_jwt( + self, async_workos + ): + from unittest.mock import AsyncMock + + sign_private, _ = _generate_rsa_key_pair() + _, verify_public = _generate_rsa_key_pair() + bad_token = _make_jwt(sign_private) + original_sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + session = AsyncSession( + client=async_workos, + session_data=original_sealed, + cookie_password=COOKIE_PASSWORD, + ) + session.jwks = self._mock_jwks(verify_public) + + api_response = { + "access_token": bad_token, + "refresh_token": "rt_new", + "user": {"id": "user_01"}, + } + session._client.request_raw = AsyncMock(return_value=api_response) + + result = await session.refresh() + assert isinstance(result, RefreshWithSessionCookieErrorResponse) + assert result.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT + assert session.session_data == original_sealed + + async def test_async_session_refresh_jwks_error_returns_invalid_jwt( + self, async_workos + ): + import jwt as pyjwt_lib + from unittest.mock import AsyncMock + + private_key, _ = _generate_rsa_key_pair() + new_token = _make_jwt(private_key) + original_sealed = seal_data( + {"refresh_token": "rt_old", "user": {"id": "user_01"}}, COOKIE_PASSWORD + ) + session = AsyncSession( + client=async_workos, + session_data=original_sealed, + cookie_password=COOKIE_PASSWORD, + ) + mock_jwks = MagicMock() + mock_jwks.get_signing_key_from_jwt.side_effect = ( + pyjwt_lib.exceptions.PyJWKClientError("Unable to find a signing key") + ) + session.jwks = mock_jwks + + api_response = { + "access_token": new_token, + "refresh_token": "rt_new", + "user": {"id": "user_01"}, + } + session._client.request_raw = AsyncMock(return_value=api_response) + + result = await session.refresh() + assert isinstance(result, RefreshWithSessionCookieErrorResponse) + assert result.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT + assert session.session_data == original_sealed