001/*-
002 * Copyright 2017 Diamond Light Source Ltd.
003 *
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 */
009
010package org.eclipse.january.dataset;
011
012import java.util.Arrays;
013import java.util.List;
014
015/**
016 * Class to run over a single dataset with NumPy broadcasting to promote shapes
017 * which have lower rank and outputs to a second dataset
018 * @since 2.1
019 */
020public class BooleanNullIterator extends BooleanIteratorBase {
021
022        /**
023         * @param a
024         * @param o (can be null for new dataset, or a)
025         */
026        public BooleanNullIterator(Dataset a, Dataset o) {
027                this(a, o, false);
028        }
029
030        /**
031         * @param a
032         * @param o (can be null for new dataset, or a)
033         * @param createIfNull if true create the output dataset if that is null
034         * (by default, can create float or complex datasets)
035         */
036        public BooleanNullIterator(Dataset a, Dataset o, boolean createIfNull) {
037                this(a, o, createIfNull, false, true);
038        }
039
040        /**
041         * @param a
042         * @param o (can be null for new dataset, or a)
043         * @param createIfNull if true create the output dataset if that is null
044         * @param allowInteger if true, can create integer datasets
045         * @param allowComplex if true, can create complex datasets
046         */
047        @SuppressWarnings("deprecation")
048        public BooleanNullIterator(Dataset a, Dataset o, boolean createIfNull, boolean allowInteger, boolean allowComplex) {
049                super(true, a, null, o);
050                List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), o == null ? null : o.getShapeRef());
051
052                BroadcastUtils.checkItemSize(a, o);
053
054                maxShape = fullShapes.remove(0);
055
056                oStride = null;
057                if (o != null && !Arrays.equals(maxShape, o.getShapeRef())) {
058                        throw new IllegalArgumentException("Output does not match broadcasted shape");
059                }
060
061                aShape = fullShapes.remove(0);
062
063                int rank = maxShape.length;
064                endrank = rank - 1;
065
066                aDataset = a.reshape(aShape);
067                aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape);
068                if (outputA) {
069                        oStride = aStride;
070                        oDelta = null;
071                        oStep = 0;
072                } else if (o != null) {
073                        oStride = BroadcastUtils.createBroadcastStrides(o, maxShape);
074                        oDelta = new int[rank];
075                        oStep = o.getElementsPerItem();
076                } else if (createIfNull) {
077                        int is = aDataset.getElementsPerItem();
078                        int dt = aDataset.getDType();
079                        if (aDataset.isComplex() && !allowComplex) {
080                                is = 1;
081                                dt = DTypeUtils.getBestFloatDType(dt);
082                        } else if (!aDataset.hasFloatingPointElements() && !allowInteger) {
083                                dt = DTypeUtils.getBestFloatDType(dt);
084                        }
085                        oDataset = DatasetFactory.zeros(is, maxShape, dt);
086                        oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape);
087                        oDelta = new int[rank];
088                        oStep = is;
089                } else {
090                        oDelta = null;
091                        oStep = 0;
092                }
093
094                pos = new int[rank];
095                aDelta = new int[rank];
096                for (int j = endrank; j >= 0; j--) {
097                        aDelta[j] = aStride[j] * aShape[j];
098                        if (oDelta != null) {
099                                oDelta[j] = oStride[j] * maxShape[j];
100                        }
101                }
102
103                aStart = aDataset.getOffset();
104                aMax = endrank < 0 ? aStep + aStart : Integer.MIN_VALUE;
105                oStart = oDelta == null ? 0 : oDataset.getOffset();
106                reset();
107        }
108
109        @Override
110        public boolean hasNext() {
111                int j = endrank;
112                for (; j >= 0; j--) {
113                        pos[j]++;
114                        index += aStride[j];
115                        if (oDelta != null) {
116                                oIndex += oStride[j];
117                        }
118                        if (pos[j] >= maxShape[j]) {
119                                pos[j] = 0;
120                                index -= aDelta[j]; // reset these dimensions
121                                if (oDelta != null) {
122                                        oIndex -= oDelta[j];
123                                }
124                        } else {
125                                break;
126                        }
127                }
128                if (j == -1) {
129                        if (endrank >= 0) {
130                                return false;
131                        }
132                        index += aStep;
133                        if (oDelta != null) {
134                                oIndex += oStep;
135                        }
136                }
137                if (outputA) {
138                        oIndex = index;
139                }
140
141                if (index == aMax) {
142                        return false;
143                }
144
145                return true;
146        }
147
148        @Override
149        public void reset() {
150                for (int i = 0; i <= endrank; i++) {
151                        pos[i] = 0;
152                }
153
154                if (endrank >= 0) {
155                        pos[endrank] = -1;
156                        index = aStart - aStride[endrank];
157                        oIndex = oStart - (oStride == null ? 0 : oStride[endrank]);
158                } else {
159                        index = aStart - aStep;
160                        oIndex = oStart - oStep;
161                }
162        }
163}