diff --git a/backend/marketfeed/migrations/0004_remove_portfolio_stocks_portfoliostock.py b/backend/marketfeed/migrations/0004_remove_portfolio_stocks_portfoliostock.py new file mode 100644 index 00000000..0ea04fc0 --- /dev/null +++ b/backend/marketfeed/migrations/0004_remove_portfolio_stocks_portfoliostock.py @@ -0,0 +1,49 @@ +# Generated by Django 4.2 on 2024-12-09 20:28 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ("marketfeed", "0003_stock_last_price_stock_last_updated"), + ] + + operations = [ + migrations.RemoveField( + model_name="portfolio", + name="stocks", + ), + migrations.CreateModel( + name="PortfolioStock", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("price_bought", models.DecimalField(decimal_places=2, max_digits=10)), + ("added_at", models.DateTimeField(auto_now_add=True)), + ( + "portfolio", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="portfolio_stocks", + to="marketfeed.portfolio", + ), + ), + ( + "stock", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to="marketfeed.stock", + ), + ), + ], + ), + ] diff --git a/backend/marketfeed/migrations/0005_portfolio_stocks.py b/backend/marketfeed/migrations/0005_portfolio_stocks.py new file mode 100644 index 00000000..9f2751d4 --- /dev/null +++ b/backend/marketfeed/migrations/0005_portfolio_stocks.py @@ -0,0 +1,22 @@ +# Generated by Django 4.2 on 2024-12-09 20:28 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("marketfeed", "0004_remove_portfolio_stocks_portfoliostock"), + ] + + operations = [ + migrations.AddField( + model_name="portfolio", + name="stocks", + field=models.ManyToManyField( + through="marketfeed.PortfolioStock", + to="marketfeed.stock", + verbose_name="list of stocks in the portfolio", + ), + ), + ] diff --git a/backend/marketfeed/migrations/0006_merge_20241210_1341.py b/backend/marketfeed/migrations/0006_merge_20241210_1341.py new file mode 100644 index 00000000..fe8c630d --- /dev/null +++ b/backend/marketfeed/migrations/0006_merge_20241210_1341.py @@ -0,0 +1,13 @@ +# Generated by Django 4.2 on 2024-12-10 13:41 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("marketfeed", "0005_index_currency_index_symbol"), + ("marketfeed", "0005_portfolio_stocks"), + ] + + operations = [] diff --git a/backend/marketfeed/migrations/0007_portfoliostock_quantity.py b/backend/marketfeed/migrations/0007_portfoliostock_quantity.py new file mode 100644 index 00000000..7a9b87ef --- /dev/null +++ b/backend/marketfeed/migrations/0007_portfoliostock_quantity.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2 on 2024-12-10 19:31 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('marketfeed', '0006_merge_20241210_1341'), + ] + + operations = [ + migrations.AddField( + model_name='portfoliostock', + name='quantity', + field=models.PositiveIntegerField(default=1), + ), + ] diff --git a/backend/marketfeed/models.py b/backend/marketfeed/models.py index 3bbcc79c..0367b8ca 100644 --- a/backend/marketfeed/models.py +++ b/backend/marketfeed/models.py @@ -66,13 +66,23 @@ def save(self, **kwargs): return super().save(**kwargs) +class PortfolioStock(models.Model): + portfolio = models.ForeignKey('Portfolio', on_delete=models.CASCADE, related_name='portfolio_stocks') + stock = models.ForeignKey('Stock', on_delete=models.CASCADE) + price_bought = models.DecimalField(max_digits=10, decimal_places=2) + quantity = models.PositiveIntegerField(default=1) + added_at = models.DateTimeField(auto_now_add=True) + + class Portfolio(models.Model): name = models.CharField(max_length=50) - description = models.CharField(max_length = 150) + description = models.CharField(max_length=150) user_id = models.ForeignKey(User, on_delete=models.CASCADE) created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(null=True, auto_now=True) - stocks = models.ManyToManyField(Stock, verbose_name="list of stocks in the portfolio") + stocks = models.ManyToManyField( + 'Stock', through='PortfolioStock', verbose_name="list of stocks in the portfolio" + ) class Post(models.Model): diff --git a/backend/marketfeed/serializers.py b/backend/marketfeed/serializers.py index 5b776826..1eee5b86 100644 --- a/backend/marketfeed/serializers.py +++ b/backend/marketfeed/serializers.py @@ -90,30 +90,63 @@ def __init__(self, *args, **kwargs): self.fields['user_id'].required = False -class PortfolioSerializer(serializers.ModelSerializer): - user_id = serializers.PrimaryKeyRelatedField(queryset=User.objects.all()) - stocks = serializers.PrimaryKeyRelatedField(queryset=Stock.objects.all(), many=True) +class PortfolioStockActionSerializer(serializers.Serializer): + portfolio_id = serializers.PrimaryKeyRelatedField(queryset=Portfolio.objects.all()) + stock = serializers.PrimaryKeyRelatedField(queryset=Stock.objects.all()) + price_bought = serializers.DecimalField(max_digits=10, decimal_places=2, required=False) + quantity = serializers.IntegerField(min_value=1, required=False) + + def validate(self, data): + if self.context['request'].method == 'POST' and self.context.get('view').action == 'add_stock': + if 'price_bought' not in data: + raise serializers.ValidationError({'price_bought': 'This field is required for adding a stock.'}) + return data + +class PortfolioStockSerializer(serializers.ModelSerializer): + stock = serializers.PrimaryKeyRelatedField(queryset=Stock.objects.all()) + price_bought = serializers.DecimalField(max_digits=10, decimal_places=2) + quantity = serializers.IntegerField(min_value=1) class Meta: - model = Portfolio - fields = ['id', 'name', 'description', 'user_id', 'created_at', 'updated_at', 'stocks'] - + model = PortfolioStock + fields = ['stock', 'price_bought', 'quantity'] + def __init__(self, *args, **kwargs): - super(PortfolioSerializer, self).__init__(*args, **kwargs) + super(PortfolioStockSerializer, self).__init__(*args, **kwargs) request = self.context.get('request', None) - if request and request.method == 'PUT': - self.fields['name'].required = False - self.fields['description'].required = False - self.fields['user_id'].required = False - self.fields['stocks'].required = False + if request and request.method == 'DELETE': + self.fields['price_bought'].required = False + self.fields['quantity'].required = False - elif request and request.method == 'POST': - self.fields['name'].required = True - self.fields['description'].required = False - self.fields['user_id'].required = True - self.fields['stocks'].required = False + +class PortfolioSerializer(serializers.ModelSerializer): + user_id = serializers.PrimaryKeyRelatedField(read_only=True) + stocks = PortfolioStockSerializer(source='portfolio_stocks', many=True, required=False) + + class Meta: + model = Portfolio + fields = ['id', 'name', 'description', 'user_id', 'created_at', 'updated_at', 'stocks'] + + def create(self, validated_data): + stocks_data = validated_data.pop('portfolio_stocks', []) + portfolio = Portfolio.objects.create(**validated_data) + for stock_data in stocks_data: + PortfolioStock.objects.create(portfolio=portfolio, **stock_data) + return portfolio + + def update(self, instance, validated_data): + if 'portfolio_stocks' in validated_data: + stocks_data = validated_data.pop('portfolio_stocks') + instance.portfolio_stocks.all().delete() + for stock_data in stocks_data: + PortfolioStock.objects.create(portfolio=instance, **stock_data) + + instance.name = validated_data.get('name', instance.name) + instance.description = validated_data.get('description', instance.description) + instance.save() + return instance class CommentSerializer(serializers.ModelSerializer): diff --git a/backend/marketfeed/urls.py b/backend/marketfeed/urls.py index e8d883b6..c6cf8e94 100644 --- a/backend/marketfeed/urls.py +++ b/backend/marketfeed/urls.py @@ -1,6 +1,6 @@ from django.urls import path, include from rest_framework.routers import DefaultRouter -from .views import CurrencyViewSet, StockViewSet, TagViewSet, PortfolioViewSet, PostViewSet, CommentViewSet, IndexViewSet +from .views import CurrencyViewSet, StockViewSet, TagViewSet, PortfolioViewSet, PostViewSet, CommentViewSet, IndexViewSet, PortfolioStockViewSet router = DefaultRouter() router.register(r'currencies', CurrencyViewSet) @@ -10,7 +10,7 @@ router.register(r'posts', PostViewSet) router.register(r'comments', CommentViewSet) router.register(r'indices', IndexViewSet) - +router.register(r'portfolio-stocks', PortfolioStockViewSet, basename='portfolio-stocks') urlpatterns = [ path('', include(router.urls)), diff --git a/backend/marketfeed/views.py b/backend/marketfeed/views.py index 105d3b2c..3e1329d6 100644 --- a/backend/marketfeed/views.py +++ b/backend/marketfeed/views.py @@ -2,6 +2,10 @@ from rest_framework import viewsets, status, permissions from rest_framework.permissions import IsAuthenticated, AllowAny, IsAuthenticatedOrReadOnly from rest_framework.response import Response +from rest_framework.decorators import action +from drf_yasg.utils import swagger_auto_schema +from drf_yasg import openapi +from rest_framework.viewsets import ViewSet from .serializers import * from .models import * from rest_framework.decorators import action @@ -167,11 +171,12 @@ def create(self, request): serializer.save(user_id=request.user) return Response(serializer.data, status=status.HTTP_201_CREATED) - def update(self, request, pk=None): - portfolio = self.get_object() - serializer = self.get_serializer(portfolio, data=request.data) + def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=partial) serializer.is_valid(raise_exception=True) - serializer.save() + self.perform_update(serializer) return Response(serializer.data) def destroy(self, request, pk=None): @@ -179,6 +184,67 @@ def destroy(self, request, pk=None): portfolio.delete() return Response(status=status.HTTP_204_NO_CONTENT) + @action(detail=False, methods=['get'], url_path='portfolios-by-user/(?P[^/.]+)') + def user_portfolios(self, request, user_id=None): + portfolios = self.queryset.filter(user_id=user_id) + serializer = self.get_serializer(portfolios, many=True) + return Response(serializer.data) + + +class PortfolioStockViewSet(ViewSet): + """ + A viewset for adding and removing stocks from a portfolio. + """ + permission_classes = [permissions.IsAuthenticated] + serializer_class = PortfolioStockActionSerializer + + def get_serializer(self, *args, **kwargs): + context = kwargs.pop('context', {}) + context['request'] = self.request + context['view'] = self + return self.serializer_class(*args, context=context, **kwargs) + + @action(detail=False, methods=['post']) + def add_stock(self, request): + serializer = self.get_serializer(data=request.data, context={'action': 'add_stock'}) + serializer.is_valid(raise_exception=True) + + portfolio = serializer.validated_data['portfolio_id'] + stock = serializer.validated_data['stock'] + price_bought = serializer.validated_data['price_bought'] + quantity = serializer.validated_data.get('quantity', 1) + + if PortfolioStock.objects.filter(portfolio=portfolio, stock=stock).exists(): + return Response({'detail': 'This stock is already in the portfolio.'}, status=status.HTTP_400_BAD_REQUEST) + + PortfolioStock.objects.create( + portfolio=portfolio, + stock=stock, + price_bought=price_bought, + quantity=quantity) + + portfolio.stocks.add(stock) + + return Response({'status': 'Stock added to portfolio'}, status=status.HTTP_201_CREATED) + + @action(detail=False, methods=['post']) + def remove_stock(self, request): + serializer = self.get_serializer(data=request.data, context={'action': 'remove_stock'}) + serializer.is_valid(raise_exception=True) + + portfolio = serializer.validated_data['portfolio_id'] + stock = serializer.validated_data['stock'] + + portfolio_stock = PortfolioStock.objects.filter(portfolio=portfolio, stock=stock) + if not portfolio_stock.exists(): + return Response({'detail': 'This stock is not in the portfolio.'}, status=status.HTTP_400_BAD_REQUEST) + + portfolio_stock.delete() + + portfolio.stocks.remove(stock) + + return Response({'status': 'Stock removed from portfolio'}, status=status.HTTP_200_OK) + class PostViewSet(viewsets.ModelViewSet): serializer_class = PostSerializer @@ -225,6 +291,12 @@ def destroy(self, request, *args, **kwargs): post = self.get_object() post.delete() return Response(status=status.HTTP_204_NO_CONTENT) + + @action(detail=False, methods=['get'], url_path='posts-by-user/(?P[^/.]+)') + def user_posts(self, request, user_id=None): + posts = self.queryset.filter(author=user_id) + serializer = self.get_serializer(posts, many=True) + return Response(serializer.data) class CommentViewSet(viewsets.ModelViewSet): @@ -270,7 +342,6 @@ def post_comments(self, request, post_id=None): return Response(serializer.data) - class IndexViewSet(viewsets.ModelViewSet): queryset = Index.objects.all() serializer_class = IndexSerializer