问题描述

设有两个长度相等、元素为数值类型的无序序列 a、b;
要求:通过交换 a、b 中的元素,使得 a 中所有元素之和与 b 中所有元素之和的差(取绝对值)最小。


求解思路

在逻辑上把两个序列看成一个长度为 2n 的序列,问题就可以转化为:从这 2n 个元素中选出 n 个,使这 n 个元素之和尽可能接近两个序列全部元素之和的一半。
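因此可以通过动态规划来求解。记合并后的序列为 $x_0, x_1, \dots, x_{2n-1}$,$f(i, k, t)$ 表示从 $x_i, \dots, x_{2n-1}$ 中恰好选出 $k$ 个元素、使其和最接近 $t$ 的选法,则:

$$
f(i, k, t) =
\begin{cases}
\{x_i, \dots, x_{2n-1}\}, & 2n - i = k \\
\varnothing, & k = 0 \\
\mathop{\arg\min}\limits_{S \in \{\, f(i+1,\, k,\, t),\ \{x_i\} \cup f(i+1,\, k-1,\, t - x_i) \,\}} \left| \sum S - t \right|, & \text{其他情况}
\end{cases}
$$

所求结果为 $f\!\left(0,\ n,\ \tfrac{1}{2}\sum_{j=0}^{2n-1} x_j\right)$。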

除了使用动态规划之外,还可以使用回溯算法求出所有可能的解,然后再从中选择最优的解。


代码实现

package cn.timd.DynamicProgramming;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Stack;

public class MinDifference {
    private List<BigDecimal> firstList;
    private List<BigDecimal> secondList;
    private BigDecimal target;
    private int size;
    private int shouldSelectSize;

    public MinDifference(List<BigDecimal> firstList, List<BigDecimal> secondList) {
        if (firstList == null || secondList == null)
            throw new RuntimeException("firstList == null or secondList == null");
        if (firstList.size() == 0 || firstList.size() != secondList.size())
            throw new RuntimeException("invalid firstList or secondList");

        this.firstList = firstList;
        this.secondList = secondList;
        target = getSum(firstList, secondList).divide(new BigDecimal(2)); // average
        size = firstList.size() + secondList.size();
        shouldSelectSize = size / 2;
    }

    private BigDecimal getSum(List<BigDecimal>... lists) {
        BigDecimal sum = new BigDecimal(0);
        for (List<BigDecimal> list: lists)
            for (BigDecimal element: list)
                sum = sum.add(element);
        return sum;
    }

    private BigDecimal getTarget() {
        return target;
    }

    public int getSize() {
        return size;
    }

    public BigDecimal get(int n) {
        if (n < 0 || n >= getSize())
            throw new RuntimeException("invalid index");
        if (n < firstList.size())
            return firstList.get(n);
        return secondList.get(n - firstList.size());
    }

    public int getShouldSelectSize() {
        return shouldSelectSize;
    }

    private List<BigDecimal> dynamicProgramming(int offset, int alreadySelectSize, BigDecimal target) {
        List<BigDecimal> decision = new ArrayList<BigDecimal>();
        if (getSize() - offset + alreadySelectSize == getShouldSelectSize()) {
            for (int n = offset; n < getSize(); ++n)
                decision.add(get(n));
            return decision;
        }

        if (alreadySelectSize == getShouldSelectSize())
            return decision;

        List<BigDecimal> notChooseOffset = dynamicProgramming(offset+1, alreadySelectSize, target);
        List<BigDecimal> chooseOffset = new ArrayList<BigDecimal>();
        chooseOffset.add(get(offset));
        chooseOffset.addAll(dynamicProgramming(offset+1, alreadySelectSize+1, target.subtract(get(offset))));

        BigDecimal notChooseOffsetSum = getSum(notChooseOffset);
        BigDecimal chooseOffsetSum = getSum(chooseOffset);

        if (notChooseOffsetSum.subtract(target).abs().compareTo
                (chooseOffsetSum.subtract(target).abs()) < 0)
            return notChooseOffset;
        return chooseOffset;
    }

    public List<BigDecimal> dynamicProgramming() {
        return dynamicProgramming(0, 0, getTarget());
    }

    private static class Node {
        BigDecimal element;
        List<Node> nextNodes;
        int offset = 0;

        Node(BigDecimal element) {
            this.element = element;
        }

        Node getNextNode() {
            if (nextNodes == null || offset >= nextNodes.size())
                return null;
            return nextNodes.get(offset++);
        }

        void clear() {
            offset = 0;
        }
    }

    private Node createNodes() {
        Node root = new Node(null);

        List<Node> nodes = new ArrayList<Node>();
        for (int i = 0; i < getSize(); ++i)
            nodes.add(new Node(get(i)));
        root.nextNodes = nodes;
        for (int i = 0; i < getSize() - 1; i++)
            nodes.get(i).nextNodes = nodes.subList(i+1, nodes.size());
        return root;
    }

    public List<BigDecimal> backTrace() {
        BigDecimal totalSum = getSum(firstList, secondList);
        List<BigDecimal> result = null;
        BigDecimal currentAbs = null;

        Node root = createNodes();
        Stack<Node> stack = new Stack<Node>();
        stack.push(root);
        Node extensionNode, nextNode;

        while (!stack.empty()) {
            extensionNode = stack.peek();
            nextNode = extensionNode.getNextNode();
            if (nextNode == null) {
                extensionNode = stack.pop();
                if (stack.empty())
                    break;
                extensionNode.clear();
            } else {
                if (stack.size() == getShouldSelectSize()) {
                    List<BigDecimal> list = new ArrayList<BigDecimal>();
                    for (Node node: stack)
                        if (node != root)
                            list.add(node.element);
                    list.add(nextNode.element);

                    BigDecimal abs = getSum(list).multiply(new BigDecimal(2)).subtract(totalSum).abs();
                    if (currentAbs == null || abs.compareTo(currentAbs) < 0) {
                        result = list;
                        currentAbs = abs;
                    }
                } else
                    stack.push(nextNode);
            }
        }
        return result;
    }

    public static void test(String[] args) {
        int[] firstArray = new int[]{12, 19, 8, 3, 16, 13, 4, 7, 1, 2};
        int[] secondArray = new int[]{10, 6, 18, 5, 11, 17, 9, 14, 20, 15};
        List<BigDecimal> firstList = new ArrayList<BigDecimal>();
        List<BigDecimal> secondList = new ArrayList<BigDecimal>();

        for (int element: firstArray)
            firstList.add(new BigDecimal(element));
        for (int element: secondArray)
            secondList.add(new BigDecimal(element));

        MinDifference minDifference = new MinDifference(firstList, secondList);
        for (BigDecimal element: minDifference.backTrace())
            System.out.print(element.toString() + " ");
        System.out.println();

        for (BigDecimal element: minDifference.dynamicProgramming())
            System.out.print(element.toString() + " ");
        System.out.println();
    }

    public static void main(String[] args) {
        test(args);
    }
}

执行结果(两行分别为回溯算法和动态规划选出的元素;它们之和为 105,恰好是全部元素之和 210 的一半,因此两个序列之和的差为 0):

12 19 8 3 16 13 4 7 6 17 
12 19 8 3 16 13 4 7 6 17 

文档