diff --git a/openedx/core/djangoapps/enrollments/serializers.py b/openedx/core/djangoapps/enrollments/serializers.py index 9b64cc95caf8..10d9ef2b3932 100644 --- a/openedx/core/djangoapps/enrollments/serializers.py +++ b/openedx/core/djangoapps/enrollments/serializers.py @@ -137,3 +137,22 @@ class Meta: model = CourseEnrollmentAllowed exclude = ["id"] lookup_field = "user" + + +class UserRoleSerializer(serializers.Serializer): # pylint: disable=abstract-method + """Serializes a single course-level role entry for a user.""" + + org = serializers.CharField() + course_id = serializers.SerializerMethodField() + role = serializers.CharField() + + def get_course_id(self, obj): + """Return course_id as a string.""" + return str(obj.course_id) + + +class UserRolesResponseSerializer(serializers.Serializer): # pylint: disable=abstract-method + """Serializes the full response payload for EnrollmentUserRolesView.""" + + roles = UserRoleSerializer(many=True) + is_staff = serializers.BooleanField() diff --git a/openedx/core/djangoapps/enrollments/tests/test_views.py b/openedx/core/djangoapps/enrollments/tests/test_views.py index 41c5a9624c24..87b475467c9f 100644 --- a/openedx/core/djangoapps/enrollments/tests/test_views.py +++ b/openedx/core/djangoapps/enrollments/tests/test_views.py @@ -36,7 +36,7 @@ from openedx.core.djangoapps.course_groups import cohorts from openedx.core.djangoapps.embargo.models import Country, CountryAccessRule, RestrictedCourse from openedx.core.djangoapps.embargo.test_utils import restrict_course -from openedx.core.djangoapps.enrollments import api, data +from openedx.core.djangoapps.enrollments import data from openedx.core.djangoapps.enrollments.errors import CourseEnrollmentError from openedx.core.djangoapps.enrollments.views import EnrollmentUserThrottle from openedx.core.djangoapps.notifications.config.waffle import ENABLE_NOTIFICATIONS @@ -711,9 +711,9 @@ def test_get_enrollment_details_bad_course(self): ) assert resp.status_code == status.HTTP_400_BAD_REQUEST - @patch.object(api, "get_enrollment") - def test_get_enrollment_internal_error(self, mock_get_enrollment): - mock_get_enrollment.side_effect = CourseEnrollmentError("Something bad happened.") + @patch.object(CourseEnrollment.objects, "get") + def test_get_enrollment_internal_error(self, mock_get): + mock_get.side_effect = CourseEnrollmentError("Something bad happened.") resp = self.client.get( reverse( 'courseenrollment', @@ -2031,3 +2031,347 @@ def test_delete_enrollment_allowed(self, delete_data, expected_result): self.client.post(self.url, self.data) response = self.client.delete(self.url, delete_data) assert response.status_code == expected_result + + # --- Response-shape tests (ADR 0025 serializer migration) --- + + def test_post_response_shape(self): + """POST 201 response contains the expected fields from CourseEnrollmentAllowedSerializer.""" + response = self.client.post(self.url, self.data) + assert response.status_code == status.HTTP_201_CREATED + body = response.json() + assert body['email'] == self.data['email'] + assert body['course_id'] == self.data['course_id'] + assert body['auto_enroll'] is False + assert 'created' in body + + def test_post_auto_enroll_true_in_response(self): + """POST with auto_enroll=true is reflected in the 201 response.""" + response = self.client.post(self.url, {**self.data, 'auto_enroll': True}) + assert response.status_code == status.HTTP_201_CREATED + assert response.json()['auto_enroll'] is True + + def test_post_missing_email_returns_field_error(self): + """POST without email returns a serializer field-level 400 with an 'email' key.""" + response = self.client.post(self.url, {'course_id': self.data['course_id']}) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert 'email' in response.json() + + def test_post_missing_course_id_returns_field_error(self): + """POST without course_id returns a serializer field-level 400 with a 'course_id' key.""" + response = self.client.post(self.url, {'email': self.data['email']}) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert 'course_id' in response.json() + + def test_post_duplicate_returns_409_with_message(self): + """A duplicate POST returns 409 with a 'message' key.""" + self.client.post(self.url, self.data) + response = self.client.post(self.url, self.data) + assert response.status_code == status.HTTP_409_CONFLICT + assert 'message' in response.json() + + def test_get_response_is_list(self): + """GET response body is a JSON list.""" + response = self.client.get(self.url, {'email': self.data['email']}) + assert response.status_code == status.HTTP_200_OK + assert isinstance(response.json(), list) + + def test_get_empty_response_is_empty_list(self): + """GET with no matching enrollments returns an empty list, not null.""" + response = self.client.get(self.url, {'email': 'nobody@example.com'}) + assert response.status_code == status.HTTP_200_OK + assert response.json() == [] + + def test_get_item_shape(self): + """Each item in the GET response has the fields from CourseEnrollmentAllowedSerializer.""" + self.client.post(self.url, self.data) + response = self.client.get(self.url, {'email': self.data['email']}) + assert response.status_code == status.HTTP_200_OK + item = response.json()[0] + assert item['email'] == self.data['email'] + assert item['course_id'] == self.data['course_id'] + assert 'auto_enroll' in item + assert 'created' in item + + def test_get_multiple_entries_returned(self): + """GET returns all enrollment-allowed records for a given email.""" + second_course = 'course-v1:edX+OtherX+Other_Course' + self.client.post(self.url, self.data) + self.client.post(self.url, {'email': self.data['email'], 'course_id': second_course}) + response = self.client.get(self.url, {'email': self.data['email']}) + assert response.status_code == status.HTTP_200_OK + results = response.json() + assert len(results) == 2 + assert all(r['email'] == self.data['email'] for r in results) + + def test_delete_missing_email_returns_field_error(self): + """DELETE without email returns a serializer field-level 400 with an 'email' key.""" + self.client.post(self.url, self.data) + response = self.client.delete(self.url, {'course_id': self.data['course_id']}) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert 'email' in response.json() + + +@skip_unless_lms +class EnrollmentViewResponseShapeTest(ModuleStoreTestCase, APITestCase): + """ + Tests that verify EnrollmentView (GET /enrollment/v1/enrollment/{course_id} and + /enrollment/v1/enrollment/{username},{course_id}) response structure is preserved + after migrating to direct serializer usage (ADR 0025). + """ + + USERNAME = "Bob" + PASSWORD = "edx" + + def setUp(self): + super().setUp() + self.course = CourseFactory.create(emit_signals=True) + self.user = UserFactory.create(username=self.USERNAME, password=self.PASSWORD) + self.client.login(username=self.USERNAME, password=self.PASSWORD) + CourseModeFactory.create( + course_id=self.course.id, + mode_slug=CourseMode.DEFAULT_MODE_SLUG, + mode_display_name=CourseMode.DEFAULT_MODE_SLUG, + ) + CourseEnrollment.enroll(self.user, self.course.id) + + def _get_by_course_id(self): + return self.client.get( + reverse('courseenrollment', kwargs={'course_id': str(self.course.id)}) + ) + + def _get_by_username_and_course_id(self): + return self.client.get( + reverse('courseenrollment', kwargs={'username': self.USERNAME, 'course_id': str(self.course.id)}) + ) + + def test_get_by_course_id_returns_200(self): + assert self._get_by_course_id().status_code == status.HTTP_200_OK + + def test_get_by_username_course_id_returns_200(self): + assert self._get_by_username_and_course_id().status_code == status.HTTP_200_OK + + def test_get_response_top_level_fields(self): + """Response contains the expected top-level enrollment fields.""" + body = self._get_by_course_id().json() + for field in ('created', 'mode', 'is_active', 'user', 'course_details'): + assert field in body, f"Missing top-level field: {field}" + + def test_get_response_user_and_mode(self): + """user and mode values match the enrollment.""" + body = self._get_by_course_id().json() + assert body['user'] == self.USERNAME + assert body['mode'] == CourseMode.DEFAULT_MODE_SLUG + assert body['is_active'] is True + + def test_get_by_username_course_id_matches_by_course_id(self): + """Both URL shapes return identical response bodies.""" + by_course = self._get_by_course_id().json() + by_username = self._get_by_username_and_course_id().json() + assert by_course == by_username + + def test_get_course_details_fields(self): + """course_details contains the expected nested fields.""" + course_details = self._get_by_course_id().json()['course_details'] + for field in ( + 'course_id', 'course_name', 'enrollment_start', 'enrollment_end', + 'course_start', 'course_end', 'invite_only', 'course_modes', 'pacing_type', + ): + assert field in course_details, f"Missing course_details field: {field}" + assert course_details['course_id'] == str(self.course.id) + + def test_get_no_enrollment_returns_null(self): + """GET for a course the user never enrolled in returns HTTP 200 with a null body.""" + unenrolled_course = CourseFactory.create(emit_signals=True) + resp = self.client.get( + reverse('courseenrollment', kwargs={'course_id': str(unenrolled_course.id)}) + ) + assert resp.status_code == status.HTTP_200_OK + assert resp.json() is None + + +@skip_unless_lms +class EnrollmentCourseDetailViewResponseShapeTest(ModuleStoreTestCase, APITestCase): + """ + Tests that verify EnrollmentCourseDetailView (GET /enrollment/v1/course/{course_id}) + response structure is preserved after migrating to CourseSerializer + direct ORM (ADR 0025). + """ + + def setUp(self): + super().setUp() + self.course = CourseFactory.create(emit_signals=True) + CourseModeFactory.create( + course_id=self.course.id, + mode_slug=CourseMode.DEFAULT_MODE_SLUG, + mode_display_name=CourseMode.DEFAULT_MODE_SLUG, + ) + + def _get_course_details(self, course_id=None, include_expired=False): + url = reverse('courseenrollmentdetails', kwargs={'course_id': course_id or str(self.course.id)}) + if include_expired: + url += '?include_expired=1' + return self.client.get(url) + + def test_returns_200(self): + assert self._get_course_details().status_code == status.HTTP_200_OK + + def test_response_top_level_fields(self): + """Response contains the expected top-level CourseSerializer fields.""" + body = self._get_course_details().json() + for field in ('course_id', 'course_name', 'enrollment_start', 'enrollment_end', + 'course_start', 'course_end', 'invite_only', 'course_modes', 'pacing_type'): + assert field in body, f"Missing field: {field}" + + def test_course_id_matches_requested_course(self): + body = self._get_course_details().json() + assert body['course_id'] == str(self.course.id) + + def test_course_modes_is_list(self): + body = self._get_course_details().json() + assert isinstance(body['course_modes'], list) + + def test_course_mode_fields(self): + """Each mode entry contains the expected fields.""" + body = self._get_course_details().json() + mode = body['course_modes'][0] + for field in ('slug', 'name', 'min_price', 'suggested_prices', 'currency', + 'expiration_datetime', 'description', 'sku', 'bulk_sku'): + assert field in mode, f"Missing course_mode field: {field}" + + def test_invalid_course_id_returns_400(self): + resp = self._get_course_details(course_id='not/a/real/course') + assert resp.status_code == status.HTTP_400_BAD_REQUEST + + def test_nonexistent_course_returns_400(self): + resp = self._get_course_details(course_id='course-v1:Org+NonExistent+2099') + assert resp.status_code == status.HTTP_400_BAD_REQUEST + + +@skip_unless_lms +class EnrollmentListViewResponseShapeTest(ModuleStoreTestCase, APITestCase): + """ + Tests that verify EnrollmentListView (GET /enrollment/v1/enrollment) + response structure is preserved after migrating to CourseEnrollmentSerializer + ORM (ADR 0025). + """ + + USERNAME = "TestLearner" + PASSWORD = "edx" + + def setUp(self): + super().setUp() + self.course = CourseFactory.create(emit_signals=True) + CourseModeFactory.create( + course_id=self.course.id, + mode_slug=CourseMode.DEFAULT_MODE_SLUG, + mode_display_name=CourseMode.DEFAULT_MODE_SLUG, + ) + self.user = UserFactory.create(username=self.USERNAME, password=self.PASSWORD) + self.client.login(username=self.USERNAME, password=self.PASSWORD) + CourseEnrollment.enroll(self.user, self.course.id) + + def _get_enrollments(self, user=None): + url = reverse('courseenrollments') + if user: + url += f'?user={user}' + return self.client.get(url) + + def test_returns_200(self): + assert self._get_enrollments().status_code == status.HTTP_200_OK + + def test_response_is_list(self): + body = self._get_enrollments().json() + assert isinstance(body, list) + + def test_enrollment_top_level_fields(self): + """Each enrollment entry contains the expected top-level fields.""" + body = self._get_enrollments().json() + assert len(body) >= 1 + entry = body[0] + for field in ('created', 'mode', 'is_active', 'user', 'course_details'): + assert field in entry, f"Missing top-level field: {field}" + + def test_enrollment_user_and_mode_values(self): + body = self._get_enrollments().json() + entry = body[0] + assert entry['user'] == self.USERNAME + assert entry['mode'] == CourseMode.DEFAULT_MODE_SLUG + assert entry['is_active'] is True + + def test_enrollment_course_details_fields(self): + """course_details nested object contains the expected fields.""" + body = self._get_enrollments().json() + course_details = body[0]['course_details'] + for field in ('course_id', 'course_name', 'enrollment_start', 'enrollment_end', + 'course_start', 'course_end', 'invite_only', 'course_modes'): + assert field in course_details, f"Missing course_details field: {field}" + + def test_no_enrollments_returns_empty_list(self): + """A user with no enrollments gets an empty list, not null or an error.""" + new_user = UserFactory.create(password=self.PASSWORD) + self.client.login(username=new_user.username, password=self.PASSWORD) + body = self.client.get(reverse('courseenrollments')).json() + assert body == [] + + +@skip_unless_lms +class UserRoleViewResponseShapeTest(ModuleStoreTestCase): + """ + Tests that verify EnrollmentUserRolesView (GET /enrollment/v1/roles/) + response structure is preserved after migrating to UserRolesResponseSerializer (ADR 0025). + """ + + USERNAME = "RoleTester" + PASSWORD = "edx" + + def setUp(self): + super().setUp() + self.course = CourseFactory.create(emit_signals=True, org="testorg", course="c1", run="r1") + self.user = UserFactory.create(username=self.USERNAME, password=self.PASSWORD) + self.client.login(username=self.USERNAME, password=self.PASSWORD) + + def _get_roles(self, course_id=None): + url = reverse('roles') + if course_id: + url += f'?course_id={course_id}' + return self.client.get(url) + + def test_returns_200(self): + assert self._get_roles().status_code == status.HTTP_200_OK + + def test_response_top_level_keys(self): + """Response always contains 'roles' (list) and 'is_staff' (bool).""" + body = self._get_roles().json() + assert 'roles' in body + assert 'is_staff' in body + assert isinstance(body['roles'], list) + assert isinstance(body['is_staff'], bool) + + def test_no_roles_returns_empty_list(self): + body = self._get_roles().json() + assert body['roles'] == [] + assert body['is_staff'] is False + + def test_role_entry_shape(self): + """A role entry contains org, course_id, and role fields.""" + role = CourseStaffRole(self.course.id) + role.add_users(self.user) + body = self._get_roles().json() + assert len(body['roles']) == 1 + entry = body['roles'][0] + for field in ('org', 'course_id', 'role'): + assert field in entry, f"Missing role field: {field}" + assert entry['org'] == self.course.org + assert entry['course_id'] == str(self.course.id) + + def test_is_staff_true_for_staff_user(self): + staff_user = UserFactory.create(password=self.PASSWORD, is_staff=True) + self.client.login(username=staff_user.username, password=self.PASSWORD) + body = self._get_roles().json() + assert body['is_staff'] is True + + def test_filter_by_course_id(self): + """course_id query param filters roles to that course only.""" + course2 = CourseFactory.create(emit_signals=True, org="other", course="c2", run="r2") + CourseStaffRole(self.course.id).add_users(self.user) + CourseStaffRole(course2.id).add_users(self.user) + body = self._get_roles(course_id=str(self.course.id)).json() + assert all(r['course_id'] == str(self.course.id) for r in body['roles']) diff --git a/openedx/core/djangoapps/enrollments/views.py b/openedx/core/djangoapps/enrollments/views.py index dc3423245e9b..23f8211d044b 100644 --- a/openedx/core/djangoapps/enrollments/views.py +++ b/openedx/core/djangoapps/enrollments/views.py @@ -35,19 +35,23 @@ from openedx.core.djangoapps.cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf from openedx.core.djangoapps.cors_csrf.decorators import ensure_csrf_cookie_cross_domain from openedx.core.djangoapps.course_groups.cohorts import CourseUserGroup, add_user_to_cohort, get_cohort_by_name -from openedx.core.djangoapps.embargo import api as embargo_api -from openedx.core.djangoapps.enrollments import api -from openedx.core.djangoapps.enrollments.errors import ( +from openedx.core.djangoapps.content.course_overviews.models import CourseOverview # lint-amnesty, pylint: disable=wrong-import-order +from openedx.core.djangoapps.embargo import api as embargo_api # lint-amnesty, pylint: disable=wrong-import-order +from openedx.core.djangoapps.enrollments import api # lint-amnesty, pylint: disable=wrong-import-order +from openedx.core.djangoapps.enrollments.errors import ( # lint-amnesty, pylint: disable=wrong-import-order CourseEnrollmentError, CourseEnrollmentExistsError, CourseModeNotFoundError, InvalidEnrollmentAttribute, ) -from openedx.core.djangoapps.enrollments.forms import CourseEnrollmentsApiListForm -from openedx.core.djangoapps.enrollments.paginators import CourseEnrollmentsApiListPagination -from openedx.core.djangoapps.enrollments.serializers import ( +from openedx.core.djangoapps.enrollments.forms import CourseEnrollmentsApiListForm # lint-amnesty, pylint: disable=wrong-import-order +from openedx.core.djangoapps.enrollments.paginators import CourseEnrollmentsApiListPagination # lint-amnesty, pylint: disable=wrong-import-order +from openedx.core.djangoapps.enrollments.serializers import ( # lint-amnesty, pylint: disable=wrong-import-order CourseEnrollmentAllowedSerializer, + CourseEnrollmentSerializer, CourseEnrollmentsApiListSerializer, + CourseSerializer, + UserRolesResponseSerializer, ) from openedx.core.djangoapps.user_api.accounts.permissions import CanRetireUser from openedx.core.djangoapps.user_api.models import UserRetirementStatus @@ -187,6 +191,7 @@ class EnrollmentView(APIView, ApiKeyPermissionMixIn): ) permission_classes = (ApiKeyHeaderPermissionIsAuthenticated,) throttle_classes = (EnrollmentUserThrottle,) + serializer_class = CourseEnrollmentSerializer # Since the course about page on the marketing site uses this API to auto-enroll users, # we need to support cross-domain CSRF. @@ -221,7 +226,17 @@ def get(self, request, course_id=None, username=None): return Response(status=status.HTTP_404_NOT_FOUND) try: - return Response(api.get_enrollment(username, course_id)) + course_key = CourseKey.from_string(course_id) + except InvalidKeyError: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"message": f"No course '{course_id}' found for enrollment"}, + ) + + try: + enrollment = CourseEnrollment.objects.get(user__username=username, course_id=course_key) + except CourseEnrollment.DoesNotExist: + return Response(None) except CourseEnrollmentError: return Response( status=status.HTTP_400_BAD_REQUEST, @@ -233,6 +248,9 @@ def get(self, request, course_id=None, username=None): }, ) + serializer = CourseEnrollmentSerializer(enrollment) + return Response(serializer.data) + class EnrollmentUserRolesView(APIView): """ @@ -266,6 +284,7 @@ class EnrollmentUserRolesView(APIView): ) permission_classes = (ApiKeyHeaderPermissionIsAuthenticated,) throttle_classes = (EnrollmentUserThrottle,) + serializer_class = UserRolesResponseSerializer @method_decorator(ensure_csrf_cookie_cross_domain) def get(self, request): @@ -286,14 +305,11 @@ def get(self, request): ) }, ) - return Response( - { - "roles": [ - {"org": role.org, "course_id": str(role.course_id), "role": role.role} for role in roles_data - ], - "is_staff": request.user.is_staff, - } - ) + serializer = UserRolesResponseSerializer({ + "roles": list(roles_data), + "is_staff": request.user.is_staff, + }) + return Response(serializer.data) @can_disable_rate_limit @@ -363,6 +379,7 @@ class EnrollmentCourseDetailView(APIView): authentication_classes = [] permission_classes = [] throttle_classes = (EnrollmentUserThrottle,) + serializer_class = CourseSerializer def get(self, request, course_id=None): """Read enrollment information for a particular course. @@ -380,12 +397,22 @@ def get(self, request, course_id=None): """ try: - return Response(api.get_course_enrollment_details(course_id, bool(request.GET.get("include_expired", "")))) - except CourseNotFoundError: + course_key = CourseKey.from_string(course_id) + except InvalidKeyError: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"message": f"No course found for course ID '{course_id}'"}, + ) + try: + course_overview = CourseOverview.get_from_id(course_key) + except CourseOverview.DoesNotExist: return Response( status=status.HTTP_400_BAD_REQUEST, - data={"message": ("No course found for course ID '{course_id}'").format(course_id=course_id)}, + data={"message": f"No course found for course ID '{course_id}'"}, ) + include_expired = bool(request.GET.get("include_expired", "")) + serializer = CourseSerializer(course_overview, include_expired=include_expired) + return Response(serializer.data) class UnenrollmentView(APIView): @@ -428,6 +455,7 @@ class UnenrollmentView(APIView): permissions.IsAuthenticated, CanRetireUser, ) + serializer_class = CourseEnrollmentSerializer def post(self, request): """ @@ -438,9 +466,10 @@ def post(self, request): username = request.data["username"] # Ensure that a retirement request status row exists for this username. UserRetirementStatus.get_retirement_for_retirement_action(username) - enrollments = api.get_enrollments(username) - active_enrollments = [enrollment for enrollment in enrollments if enrollment["is_active"]] - if len(active_enrollments) < 1: + active_enrollments = CourseEnrollment.objects.filter( + user__username=username, is_active=True + ) + if not active_enrollments.exists(): return Response(status=status.HTTP_204_NO_CONTENT) return Response(api.unenroll_user_from_all_courses(username)) except KeyError: @@ -633,6 +662,7 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn): ) permission_classes = (ApiKeyHeaderPermissionIsAuthenticated,) throttle_classes = (EnrollmentUserThrottle,) + serializer_class = CourseEnrollmentSerializer # Since the course about page on the marketing site # uses this API to auto-enroll users, we need to support @@ -656,29 +686,22 @@ def get(self, request): courses. """ username = request.GET.get("user", request.user.username) - try: - enrollment_data = api.get_enrollments(username) - except CourseEnrollmentError: - return Response( - status=status.HTTP_400_BAD_REQUEST, - data={ - "message": ("An error occurred while retrieving enrollments for user '{username}'").format( - username=username - ) - }, - ) + enrollments = CourseEnrollment.objects.filter( + user__username=username + ).select_related("user", "course_overview") if ( username == request.user.username or GlobalStaff().has_user(request.user) or self.has_api_key_permissions(request) ): - return Response(enrollment_data) - filtered_data = [] - for enrollment in enrollment_data: - course_key = CourseKey.from_string(enrollment["course_details"]["course_id"]) - if user_has_role(request.user, CourseStaffRole(course_key)): - filtered_data.append(enrollment) - return Response(filtered_data) + serializer = CourseEnrollmentSerializer(enrollments, many=True) + return Response(serializer.data) + filtered_enrollments = [ + enrollment for enrollment in enrollments + if user_has_role(request.user, CourseStaffRole(enrollment.course_id)) + ] + serializer = CourseEnrollmentSerializer(filtered_enrollments, many=True) + return Response(serializer.data) def post(self, request): # pylint: disable=too-many-statements @@ -929,14 +952,22 @@ def post(self, request): finally: # Assumes that the ecommerce service uses an API key to authenticate. if has_api_key_permissions: - current_enrollment = api.get_enrollment(username, str(course_id)) + try: + current_enrollment_obj = CourseEnrollment.objects.get( + user__username=username, course_id=course_id + ) + actual_mode = current_enrollment_obj.mode + actual_activation = current_enrollment_obj.is_active + except CourseEnrollment.DoesNotExist: + actual_mode = None + actual_activation = None audit_log( "enrollment_change_requested", course_id=str(course_id), requested_mode=mode, - actual_mode=current_enrollment["mode"] if current_enrollment else None, + actual_mode=actual_mode, requested_activation=is_active, - actual_activation=current_enrollment["is_active"] if current_enrollment else None, + actual_activation=actual_activation, user_id=user.id, ) @@ -1087,12 +1118,9 @@ def get(self, request): if not user_email: user_email = request.user.email - enrollments_allowed = CourseEnrollmentAllowed.objects.filter(email=user_email) or [] - serialized_enrollments_allowed = [ - CourseEnrollmentAllowedSerializer(enrollment).data for enrollment in enrollments_allowed - ] - - return Response(status=status.HTTP_200_OK, data=serialized_enrollments_allowed) + enrollments_allowed = CourseEnrollmentAllowed.objects.filter(email=user_email) + serializer = CourseEnrollmentAllowedSerializer(enrollments_allowed, many=True) + return Response(status=status.HTTP_200_OK, data=serializer.data) def post(self, request): """ @@ -1126,23 +1154,24 @@ def post(self, request): - 403: Forbidden, you need to be staff. - 409: Conflict, enrollment allowed already exists. """ - is_bad_request_response, email, course_id = self.check_required_data(request) - auto_enroll = request.data.get("auto_enroll", False) - if is_bad_request_response: - return is_bad_request_response + serializer = CourseEnrollmentAllowedSerializer(data=request.data) + if not serializer.is_valid(): + return Response(status=status.HTTP_400_BAD_REQUEST, data=serializer.errors) try: - enrollment_allowed = CourseEnrollmentAllowed.objects.create( - email=email, course_id=course_id, auto_enroll=auto_enroll - ) + enrollment_allowed = serializer.save() except IntegrityError: return Response( status=status.HTTP_409_CONFLICT, - data={"message": f"An enrollment allowed with email {email} and course {course_id} already exists."}, + data={ + "message": ( + f"An enrollment allowed with email {serializer.validated_data.get('email')} " + f"and course {serializer.validated_data.get('course_id')} already exists." + ) + }, ) - serializer = CourseEnrollmentAllowedSerializer(enrollment_allowed) - return Response(status=status.HTTP_201_CREATED, data=serializer.data) + return Response(status=status.HTTP_201_CREATED, data=CourseEnrollmentAllowedSerializer(enrollment_allowed).data) def delete(self, request): """ @@ -1174,32 +1203,18 @@ def delete(self, request): - 403: Forbidden, you need to be staff. - 404: Not found, the course enrollment allowed doesn't exists. """ - is_bad_request_response, email, course_id = self.check_required_data(request) - if is_bad_request_response: - return is_bad_request_response + serializer = CourseEnrollmentAllowedSerializer(data=request.data) + if not serializer.is_valid(): + return Response(status=status.HTTP_400_BAD_REQUEST, data=serializer.errors) + + email = serializer.validated_data.get("email") + course_id = serializer.validated_data.get("course_id") try: CourseEnrollmentAllowed.objects.get(email=email, course_id=course_id).delete() - return Response( - status=status.HTTP_204_NO_CONTENT, - ) + return Response(status=status.HTTP_204_NO_CONTENT) except ObjectDoesNotExist: return Response( status=status.HTTP_404_NOT_FOUND, data={"message": f"An enrollment allowed with email {email} and course {course_id} doesn't exists."}, ) - - def check_required_data(self, request): - """ - Check if the request has email and course_id. - """ - email = request.data.get("email") - course_id = request.data.get("course_id") - if not email or not course_id: - is_bad_request = Response( - status=status.HTTP_400_BAD_REQUEST, - data={"message": "Please provide a value for 'email' and 'course_id' in the request data."}, - ) - else: - is_bad_request = None - return (is_bad_request, email, course_id)